@@ -17,11 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.metrics.source.CodegenMetrics
-import org.apache.spark.sql.{QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -145,10 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
         .select("int")
 
       val plan = df.queryExecution.executedPlan
-      assert(!plan.find(p =>
+      assert(plan.find(p =>
         p.isInstanceOf[WholeStageCodegenExec] &&
           p.asInstanceOf[WholeStageCodegenExec].child.children(0)
-            .isInstanceOf[SortMergeJoinExec]).isDefined)
+            .isInstanceOf[SortMergeJoinExec]).isEmpty)
       assert(df.collect() === Array(Row(1), Row(2)))
     }
   }
@@ -181,6 +179,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
   }
 
+  def genCode(ds: Dataset[_]): Seq[CodeAndComment] = {
+    val plan = ds.queryExecution.executedPlan
+    val wholeStageCodeGenExecs = plan.collect { case p: WholeStageCodegenExec => p }
+    assert(wholeStageCodeGenExecs.nonEmpty, "WholeStageCodegenExec is expected")
+    wholeStageCodeGenExecs.map(_.doCodeGen()._2)
+  }
+
   ignore("SPARK-21871 check if we can get large code size when compiling too long functions") {
     val codeWithShortFunctions = genGroupByCode(3)
     val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
@@ -241,9 +246,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       val df = spark.range(100)
       val join = df.join(df, "id")
       val plan = join.queryExecution.executedPlan
-      assert(!plan.find(p =>
+      assert(plan.find(p =>
         p.isInstanceOf[WholeStageCodegenExec] &&
-          p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
+          p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty,
         "codegen stage IDs should be preserved through ReuseExchange")
       checkAnswer(join, df.toDF)
     }
@@ -253,18 +258,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     import testImplicits._
 
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
-      val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME
-
-      // the same query run twice should hit the codegen cache
-      spark.range(3).select('id + 2).collect
-      val after1 = bytecodeSizeHisto.getCount
-      spark.range(3).select('id + 2).collect
-      val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately
-      // bytecodeSizeHisto's count is always monotonically increasing if new compilation to
-      // bytecode had occurred. If the count stayed the same that means we've got a cache hit.
-      assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")
-
-      // a different query can result in codegen cache miss, that's by design
+      // the same query run twice should produce identical code, which would imply a hit in
+      // the generated code cache.
+      val ds1 = spark.range(3).select('id + 2)
+      val code1 = genCode(ds1)
+      val ds2 = spark.range(3).select('id + 2)
+      val code2 = genCode(ds2) // same query shape as above, deliberately
+      assert(code1 == code2, "Should produce same code")
     }
   }
 
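As a rough, standalone illustration of the pattern the replacement test relies on, the sketch below collects the generated code for two structurally identical queries and compares it. The object name, the local SparkSession setup, and the example query are assumptions for illustration only; genCode mirrors the helper added to the suite above, and CodeAndComment equality compares the generated source body.

// Standalone sketch, not part of the commit. SparkSession setup and the query
// are illustrative assumptions; genCode mirrors the suite's new helper.
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.expressions.codegen.CodeAndComment
import org.apache.spark.sql.execution.WholeStageCodegenExec

object GenCodeSketch {

  // Collect every whole-stage-codegen subtree in the executed plan and return
  // the source each one would generate (doCodeGen()._2 is the CodeAndComment).
  def genCode(ds: Dataset[_]): Seq[CodeAndComment] = {
    val plan = ds.queryExecution.executedPlan
    val stages = plan.collect { case p: WholeStageCodegenExec => p }
    assert(stages.nonEmpty, "WholeStageCodegenExec is expected")
    stages.map(_.doCodeGen()._2)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("gen-code-sketch")
      .getOrCreate()
    import spark.implicits._

    // Two structurally identical queries should generate identical source text.
    val code1 = genCode(spark.range(3).select($"id" + 2))
    val code2 = genCode(spark.range(3).select($"id" + 2))
    assert(code1 == code2, "identical queries should produce identical generated code")

    spark.stop()
  }
}

Since the codegen compilation cache is keyed on the generated source, identical text implies the second query can reuse the already compiled class, which is the caching behaviour the original metrics-based check was trying to observe.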