Skip to content

Commit 60b89cf

Browse files
gatorsmile authored and dongjoon-hyun committed
[SPARK-28361][SQL][TEST] Test equality of generated code with id in class name
A codegen test in WholeStageCodegenSuite was flaky because it used the CodegenMetrics class to check whether the generated code for equivalent plans was identical when SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME is enabled. This patch switches the test to compare the generated code directly.

How was this patch tested? N/A

Closes apache#25131 from gatorsmile/WholeStageCodegenSuite.

Authored-by: gatorsmile <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent aa41dce commit 60b89cf

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.apache.spark.metrics.source.CodegenMetrics
21-
import org.apache.spark.sql.{QueryTest, Row, SaveMode}
20+
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
2221
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
2322
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
24-
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
2523
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
2624
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
2725
import org.apache.spark.sql.expressions.scalalang.typed
@@ -145,10 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
145143
.select("int")
146144

147145
val plan = df.queryExecution.executedPlan
148-
assert(!plan.find(p =>
146+
assert(plan.find(p =>
149147
p.isInstanceOf[WholeStageCodegenExec] &&
150148
p.asInstanceOf[WholeStageCodegenExec].child.children(0)
151-
.isInstanceOf[SortMergeJoinExec]).isDefined)
149+
.isInstanceOf[SortMergeJoinExec]).isEmpty)
152150
assert(df.collect() === Array(Row(1), Row(2)))
153151
}
154152
}
@@ -181,6 +179,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
181179
wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
182180
}
183181

182+
/**
 * Generates the whole-stage code for every [[WholeStageCodegenExec]] node found in the
 * executed plan of `ds`, without running the query.
 *
 * Fails the assertion if the plan contains no [[WholeStageCodegenExec]] node.
 *
 * @param ds the Dataset whose executed plan is inspected
 * @return one [[CodeAndComment]] per [[WholeStageCodegenExec]] node, in the order
 *         returned by `plan.collect`
 */
def genCode(ds: Dataset[_]): Seq[CodeAndComment] = {
183+
val plan = ds.queryExecution.executedPlan
184+
val wholeStageCodeGenExecs = plan.collect { case p: WholeStageCodegenExec => p }
185+
assert(wholeStageCodeGenExecs.nonEmpty, "WholeStageCodegenExec is expected")
186+
wholeStageCodeGenExecs.map(_.doCodeGen()._2)
187+
}
188+
184189
ignore("SPARK-21871 check if we can get large code size when compiling too long functions") {
185190
val codeWithShortFunctions = genGroupByCode(3)
186191
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
@@ -241,9 +246,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
241246
val df = spark.range(100)
242247
val join = df.join(df, "id")
243248
val plan = join.queryExecution.executedPlan
244-
assert(!plan.find(p =>
249+
assert(plan.find(p =>
245250
p.isInstanceOf[WholeStageCodegenExec] &&
246-
p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
251+
p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty,
247252
"codegen stage IDs should be preserved through ReuseExchange")
248253
checkAnswer(join, df.toDF)
249254
}
@@ -253,18 +258,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
253258
import testImplicits._
254259

255260
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
256-
val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME
257-
258-
// the same query run twice should hit the codegen cache
259-
spark.range(3).select('id + 2).collect
260-
val after1 = bytecodeSizeHisto.getCount
261-
spark.range(3).select('id + 2).collect
262-
val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately
263-
// bytecodeSizeHisto's count is always monotonically increasing if new compilation to
264-
// bytecode had occurred. If the count stayed the same that means we've got a cache hit.
265-
assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")
266-
267-
// a different query can result in codegen cache miss, that's by design
261+
// the same query run twice should produce identical code, which would imply a hit in
262+
// the generated code cache.
263+
val ds1 = spark.range(3).select('id + 2)
264+
val code1 = genCode(ds1)
265+
val ds2 = spark.range(3).select('id + 2)
266+
val code2 = genCode(ds2) // same query shape as above, deliberately
267+
assert(code1 == code2, "Should produce same code")
268268
}
269269
}
270270

0 commit comments

Comments
 (0)