Skip to content

Commit c377e49

Browse files
committed
[SPARK-16489][SQL] Guard against variable reuse mistakes in expression code generation
## What changes were proposed in this pull request? In code generation, it is incorrect for an expression to reuse variable names across different instances of itself. As an example, SPARK-16488 reports a bug in which the pmod expression reuses the variable name "r". This patch updates the ExpressionEvalHelper test harness to always project two instances of the same expression, which will help us catch variable reuse problems in expression unit tests. This patch also fixes the same variable reuse bug in the crc32 expression. ## How was this patch tested? This is a test harness change, but I also created a new test suite for testing the test harness. Author: Reynold Xin <rxin@databricks.com> Closes apache#14146 from rxin/SPARK-16489.
1 parent 5ad68ba commit c377e49

File tree

4 files changed

+68
-22
lines changed

4 files changed

+68
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,12 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
175175

176176
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
177177
val CRC32 = "java.util.zip.CRC32"
178+
val checksum = ctx.freshName("checksum")
178179
nullSafeCodeGen(ctx, ev, value => {
179180
s"""
180-
$CRC32 checksum = new $CRC32();
181-
checksum.update($value, 0, $value.length);
182-
${ev.value} = checksum.getValue();
181+
$CRC32 $checksum = new $CRC32();
182+
$checksum.update($value, 0, $value.length);
183+
${ev.value} = $checksum.getValue();
183184
"""
184185
})
185186
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,23 +132,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
132132
expression: Expression,
133133
expected: Any,
134134
inputRow: InternalRow = EmptyRow): Unit = {
135-
135+
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
136+
// some expression is reusing variable names across different instances.
137+
// This behavior is tested in ExpressionEvalHelperSuite.
136138
val plan = generateProject(
137-
GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
139+
GenerateUnsafeProjection.generate(
140+
Alias(expression, s"Optimized($expression)1")() ::
141+
Alias(expression, s"Optimized($expression)2")() :: Nil),
138142
expression)
139143

140144
val unsafeRow = plan(inputRow)
141145
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
142146

143147
if (expected == null) {
144148
if (!unsafeRow.isNullAt(0)) {
145-
val expectedRow = InternalRow(expected)
149+
val expectedRow = InternalRow(expected, expected)
146150
fail("Incorrect evaluation in unsafe mode: " +
147151
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
148152
}
149153
} else {
150-
val lit = InternalRow(expected)
151-
val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit)
154+
val lit = InternalRow(expected, expected)
155+
val expectedRow =
156+
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
152157
if (unsafeRow != expectedRow) {
153158
fail("Incorrect evaluation in unsafe mode: " +
154159
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23+
import org.apache.spark.sql.types.{DataType, IntegerType}
24+
25+
/**
26+
* A test suite for testing [[ExpressionEvalHelper]].
27+
*
28+
* Yes, we should write test cases for test harnesses, in case
29+
* they have behaviors that are easy to break.
30+
*/
31+
class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper {
32+
33+
test("SPARK-16489 checkEvaluation should fail if expression reuses variable names") {
34+
val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) }
35+
assert(e.getMessage.contains("some_variable"))
36+
}
37+
}
38+
39+
/**
40+
* An expression that generates bad code (variable name "some_variable" is not unique across
41+
* instances of the expression).
42+
*/
43+
case class BadCodegenExpression() extends LeafExpression {
44+
override def nullable: Boolean = false
45+
override def eval(input: InternalRow): Any = 10
46+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
47+
ev.copy(code =
48+
s"""
49+
|int some_variable = 11;
50+
|int ${ev.value} = 10;
51+
""".stripMargin)
52+
}
53+
override def dataType: DataType = IntegerType
54+
}

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
449449
}
450450
}
451451

452-
test("pmod with partitionBy") {
453-
val spark = this.spark
454-
import spark.implicits._
455-
456-
case class Test(a: Int, b: String)
457-
val data = Seq((0, "a"), (1, "b"), (1, "a"))
458-
spark.createDataset(data).createOrReplaceTempView("test")
459-
sql("select * from test distribute by pmod(_1, 2)")
460-
.write
461-
.partitionBy("_2")
462-
.mode("overwrite")
463-
.parquet(dir)
464-
}
465-
466452
private def testRead(
467453
df: => DataFrame,
468454
expectedResult: Seq[String],

0 commit comments

Comments
 (0)