From e4c1559d32084df998a82df89662824e361ad696 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Mar 2018 00:20:03 -0700 Subject: [PATCH 1/4] fix --- .../optimizer/complexTypesSuite.scala | 178 +++++++++++++----- .../apache/spark/sql/ComplexTypesSuite.scala | 109 +++++++++++ 2 files changed, 235 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e44a6692ad8e..d75439c5bf22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull - val nullableIdAtt = ('nullable_id).long + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt, nullableIdAtt) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - 
comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. 
select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") 
- .analyze val expected = relation .select( CaseWhen(Seq( @@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // but it cannot override a potential match with ('id + 2L), // which is exactly what 
[[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .where('id > 0L) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .where('id > 0L) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .where('id > 0L) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .where('id > 0L) + checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") - .groupBy($"foo")("1").analyze + .groupBy($"foo")("1") val structExpected = relation .select('nullable_id as "foo") - .groupBy($"foo")("1").analyze - comparePlans(Optimizer execute structRel, structExpected) + .groupBy($"foo")("1") + checkRule(structRel, structExpected) // These tests must use nullable attributes from the base relation for the following reason: // in the 'original' plans below, the Aggregate node produced by groupBy() has a @@ -351,17 +389,17 @@ class 
ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") - .groupBy($"a1")("1").analyze - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze - comparePlans(Optimizer execute arrayRel, arrayExpected) + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) val mapRel = relation .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") - .groupBy($"m1")("1").analyze + .groupBy($"m1")("1") val mapExpected = relation .select('nullable_id as "m1") - .groupBy($"m1")("1").analyze - comparePlans(Optimizer execute mapRel, mapExpected) + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) } test("SPARK-23500: Ensure that aggregation expressions are not simplified") { @@ -369,11 +407,47 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze - comparePlans(Optimizer execute structAggRel, structAggRel) + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze - comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + } + + test("SPARK-23500: do not simplify maps in Aggregate expressions") { + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. 
+ val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 000000000000..b74fe2f90df2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + 
checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } +} From ec470775ae4c901d02f6f9f6a7d0f85a3ebc18be Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Mar 2018 23:25:26 -0700 Subject: [PATCH 2/4] fix --- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index d75439c5bf22..187004e15487 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -338,13 +338,11 @@ class ComplexTypesSuite extends PlanTest with
ExpressionEvalHelper { numFields = 1, containsNull = false), ordinal = 1) as "a2") - .where('id > 0L) val expected = LocalRelation('id.long) .select( ('id + 1L) as "a1", ('id + 1L) as "a2") - .where('id > 0L) checkRule(query, expected) } @@ -413,9 +411,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val arrayAggRel = relation.groupBy( CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) checkRule(arrayAggRel, arrayAggRel) - } - test("SPARK-23500: do not simplify maps in Aggregate expressions") { // This could be done if we had a more complex rule that checks that // the CreateMap does not come from key. val originalQuery = relation From 888645e59657640b7049e91af5e4790e1d95ef04 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 29 Mar 2018 16:18:24 -0700 Subject: [PATCH 3/4] fix --- .../apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 187004e15487..1c07a3ba62e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -338,11 +338,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { numFields = 1, containsNull = false), ordinal = 1) as "a2") + .orderBy('id.asc) val expected = LocalRelation('id.long) .select( ('id + 1L) as "a1", ('id + 1L) as "a2") + .orderBy('id.asc) checkRule(query, expected) } From 5b990e5ff32e4ffee2c5cf9fa1df3690c37b429d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 29 Mar 2018 23:58:16 -0700 Subject: [PATCH 4/4] fix --- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 1c07a3ba62e1..21ed987627b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -358,14 +358,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetMapValue('m, "r1") as "a1", GetMapValue('m, "r32") as "a2") - .where('id > 0L) + .orderBy('id.asc) .select('a1, 'a2) val expected = LocalRelation('id.long).select( 'id as "a1", Literal.create(null, LongType) as "a2") - .where('id > 0L) + .orderBy('id.asc) checkRule(query, expected) }