Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,42 +47,46 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
SimplifyExtractValueOps) :: Nil
}

val idAtt = ('id).long.notNull
val nullableIdAtt = ('nullable_id).long
private val idAtt = ('id).long.notNull
private val nullableIdAtt = ('nullable_id).long

lazy val relation = LocalRelation(idAtt, nullableIdAtt)
private val relation = LocalRelation(idAtt, nullableIdAtt)
private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)

private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
val optimized = Optimizer.execute(originalQuery.analyze)
assert(optimized.resolved, "optimized plans must be still resolvable")
comparePlans(optimized, correctAnswer.analyze)
}

test("explicit get from namedStruct") {
val query = relation
.select(
GetStructField(
CreateNamedStruct(Seq("att", 'id )),
0,
None) as "outerAtt").analyze
val expected = relation.select('id as "outerAtt").analyze
None) as "outerAtt")
val expected = relation.select('id as "outerAtt")

comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("explicit get from named_struct- expression maintains original deduced alias") {
val query = relation
.select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None))
.analyze

val expected = relation
.select('id as "named_struct(att, id).att")
.analyze

comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("collapsed getStructField ontop of namedStruct") {
val query = relation
.select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
.select(GetStructField('struct1, 0, None) as "struct1Att")
.analyze
val expected = relation.select('id as "struct1Att").analyze
comparePlans(Optimizer execute query, expected)
val expected = relation.select('id as "struct1Att")
checkRule(query, expected)
}

test("collapse multiple CreateNamedStruct/GetStructField pairs") {
Expand All @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None) as "struct1Att1",
GetStructField('struct1, 1, None) as "struct1Att2")
.analyze

val expected =
relation.
select(
'id as "struct1Att1",
('id * 'id) as "struct1Att2")
.analyze

comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("collapsed2 - deduced names") {
Expand All @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None),
GetStructField('struct1, 1, None))
.analyze

val expected =
relation.
select(
'id as "struct1.att1",
('id * 'id) as "struct1.att2")
.analyze

comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("simplified array ops") {
Expand All @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
1,
false),
1) as "a4")
.analyze

val expected = relation
.select(
Expand All @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
"att2", (('id + 1L) * ('id + 1L)))) as "a2",
('id + 1L) as "a3",
('id + 1L) as "a4")
.analyze
comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("SPARK-22570: CreateArray should not create a lot of global variables") {
Expand All @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
GetStructField(GetMapValue('m, "r1"), 0, None) as "a2",
GetMapValue('m, "r32") as "a3",
GetStructField(GetMapValue('m, "r32"), 0, None) as "a4")
.analyze

val expected =
relation.select(
Expand All @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
)
) as "a3",
Literal.create(null, LongType) as "a4")
.analyze
comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("simplify map ops, constant lookup, dynamic keys") {
Expand All @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
.analyze

val expected = relation
.select(
Expand All @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo(13L, ('id + 1L)), ('id + 2L)),
(EqualTo(13L, ('id + 2L)), ('id + 3L)),
(Literal(true), 'id))) as "a")
.analyze
comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") {
Expand All @@ -240,16 +234,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
('id + 3L)) as "a")
.analyze
val expected = relation
.select(
CaseWhen(Seq(
(EqualTo('id + 3L, 'id), ('id + 1L)),
(EqualTo('id + 3L, ('id + 1L)), ('id + 2L)),
(EqualTo('id + 3L, ('id + 2L)), ('id + 3L)),
(Literal(true), ('id + 4L)))) as "a")
.analyze
comparePlans(Optimizer execute query, expected)
checkRule(query, expected)
}

test("simplify map ops, no positive match") {
Expand All @@ -263,16 +255,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
'id + 30L) as "a")
.analyze
val expected = relation.select(
CaseWhen(Seq(
(EqualTo('id + 30L, 'id), ('id + 1L)),
(EqualTo('id + 30L, ('id + 1L)), ('id + 2L)),
(EqualTo('id + 30L, ('id + 2L)), ('id + 3L)),
(EqualTo('id + 30L, ('id + 3L)), ('id + 4L)),
(EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a")
.analyze
comparePlans(Optimizer execute rel, expected)
checkRule(rel, expected)
}

test("simplify map ops, constant lookup, mixed keys, eliminated constants") {
Expand All @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
.analyze

val expected = relation
.select(
Expand All @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 2L), ('id + 3L),
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))) as "a")
.analyze

comparePlans(Optimizer execute rel, expected)
checkRule(rel, expected)
}

test("simplify map ops, potential dynamic match with null value + an absolute constant match") {
Expand All @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
2L ) as "a")
.analyze

val expected = relation
.select(
Expand All @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
// but it cannot override a potential match with ('id + 2L),
// which is exactly what [[Coalesce]] would do in this case.
(Literal.TrueLiteral, 'id))) as "a")
.analyze
comparePlans(Optimizer execute rel, expected)
checkRule(rel, expected)
}

test("SPARK-23500: Simplify array ops that are not at the top node") {
val query = LocalRelation('id.long)
.select(
CreateArray(Seq(
CreateNamedStruct(Seq(
"att1", 'id,
"att2", 'id * 'id)),
CreateNamedStruct(Seq(
"att1", 'id + 1,
"att2", ('id + 1) * ('id + 1))
))
) as "arr")
.select(
GetStructField(GetArrayItem('arr, 1), 0, None) as "a1",
GetArrayItem(
GetArrayStructFields('arr,
StructField("att1", LongType, nullable = false),
ordinal = 0,
numFields = 1,
containsNull = false),
ordinal = 1) as "a2")
.orderBy('id.asc)

val expected = LocalRelation('id.long)
.select(
('id + 1L) as "a1",
('id + 1L) as "a2")
.orderBy('id.asc)
checkRule(query, expected)
}

test("SPARK-23500: Simplify map ops that are not top nodes") {
val query =
LocalRelation('id.long)
.select(
CreateMap(Seq(
"r1", 'id,
"r2", 'id + 1L)) as "m")
.select(
GetMapValue('m, "r1") as "a1",
GetMapValue('m, "r32") as "a2")
.where('id > 0L)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use sort here too?

.select('a1, 'a2)

val expected =
LocalRelation('id.long).select(
'id as "a1",
Literal.create(null, LongType) as "a2")
.where('id > 0L)
checkRule(query, expected)
}

test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
val structRel = relation
.select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
.groupBy($"foo")("1").analyze
.groupBy($"foo")("1")
val structExpected = relation
.select('nullable_id as "foo")
.groupBy($"foo")("1").analyze
comparePlans(Optimizer execute structRel, structExpected)
.groupBy($"foo")("1")
checkRule(structRel, structExpected)

// These tests must use nullable attributes from the base relation for the following reason:
// in the 'original' plans below, the Aggregate node produced by groupBy() has a
Expand All @@ -351,29 +389,63 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
// SPARK-23634.
val arrayRel = relation
.select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
.groupBy($"a1")("1").analyze
val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
comparePlans(Optimizer execute arrayRel, arrayExpected)
.groupBy($"a1")("1")
val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1")
checkRule(arrayRel, arrayExpected)

val mapRel = relation
.select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
.groupBy($"m1")("1").analyze
.groupBy($"m1")("1")
val mapExpected = relation
.select('nullable_id as "m1")
.groupBy($"m1")("1").analyze
comparePlans(Optimizer execute mapRel, mapExpected)
.groupBy($"m1")("1")
checkRule(mapRel, mapExpected)
}

test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
// Make sure that aggregation exprs are correctly ignored. Maps can't be used in
// grouping exprs so aren't tested here.
val structAggRel = relation.groupBy(
CreateNamedStruct(Seq("att1", 'nullable_id)))(
GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
comparePlans(Optimizer execute structAggRel, structAggRel)
GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None))
checkRule(structAggRel, structAggRel)

val arrayAggRel = relation.groupBy(
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
checkRule(arrayAggRel, arrayAggRel)

// This could be done if we had a more complex rule that checks that
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we merge it with the previous test case?

// the CreateMap does not come from key.
val originalQuery = relation
.groupBy('id)(
GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
)
checkRule(originalQuery, originalQuery)
}

test("SPARK-23500: namedStruct and getField in the same Project #1") {
val originalQuery =
testRelation
.select(
namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b)
.select('s1 getField "col2" as 's1Col2,
namedStruct("col1", 'a, "col2", 'b).as("s2"))
.select('s1Col2, 's2 getField "col2" as 's2Col2)
val correctAnswer =
testRelation
.select('c as 's1Col2, 'b as 's2Col2)
checkRule(originalQuery, correctAnswer)
}

test("SPARK-23500: namedStruct and getField in the same Project #2") {
val originalQuery =
testRelation
.select(
namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2,
namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1)
val correctAnswer =
testRelation
.select('c as 'sCol2, 'a as 'sCol1)
checkRule(originalQuery, correctAnswer)
}
}
Loading