Skip to content

Commit 716d88f

Browse files
committed
Address comments.
1 parent f883c2b commit 716d88f

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
199199
inputRow: InternalRow = EmptyRow): Unit = {
200200
for (fallbackMode <- Seq("CODEGEN_ONLY", "NO_CODEGEN")) {
201201
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode) {
202-
val factory = UnsafeProjection
203-
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
202+
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
204203
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
205204

206205
if (expected == null) {
@@ -212,7 +211,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
212211
} else {
213212
val lit = InternalRow(expected, expected)
214213
val expectedRow =
215-
factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
214+
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
216215
if (unsafeRow != expectedRow) {
217216
fail("Incorrect evaluation in unsafe mode: " +
218217
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -224,8 +223,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
224223

225224
protected def evaluateWithUnsafeProjection(
226225
expression: Expression,
227-
inputRow: InternalRow = EmptyRow,
228-
factory: UnsafeProjection.type = UnsafeProjection): InternalRow = {
226+
inputRow: InternalRow = EmptyRow): InternalRow = {
229227
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
230228
// some expression is reusing variable names across different instances.
231229
// This behavior is tested in ExpressionEvalHelperSuite.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
3535

3636
private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
3737

38-
private def testWithFactory(
39-
name: String)(
40-
f: UnsafeProjection.type => Unit): Unit = {
41-
val factory = UnsafeProjection
42-
test(name) {
43-
for (fallbackMode <- Seq("CODEGEN_ONLY", "NO_CODEGEN")) {
38+
private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = {
39+
for (fallbackMode <- Seq("CODEGEN_ONLY", "NO_CODEGEN")) {
40+
test(name + " with " + fallbackMode) {
4441
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode) {
45-
f(factory)
42+
f
4643
}
4744
}
4845
}
4946
}
5047

51-
testWithFactory("basic conversion with only primitive types") { factory =>
48+
testBothCodegenAndInterpreted("basic conversion with only primitive types") {
49+
val factory = UnsafeProjection
5250
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
5351
val converter = factory.create(fieldTypes)
5452
val row = new SpecificInternalRow(fieldTypes)
@@ -85,7 +83,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
8583
assert(unsafeRow2.getInt(2) === 2)
8684
}
8785

88-
testWithFactory("basic conversion with primitive, string and binary types") { factory =>
86+
testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") {
87+
val factory = UnsafeProjection
8988
val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
9089
val converter = factory.create(fieldTypes)
9190

@@ -104,7 +103,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
104103
assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8))
105104
}
106105

107-
testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory =>
106+
testBothCodegenAndInterpreted(
107+
"basic conversion with primitive, string, date and timestamp types") {
108+
val factory = UnsafeProjection
108109
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
109110
val converter = factory.create(fieldTypes)
110111

@@ -133,7 +134,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
133134
(Timestamp.valueOf("2015-06-22 08:10:25"))
134135
}
135136

136-
testWithFactory("null handling") { factory =>
137+
testBothCodegenAndInterpreted("null handling") {
138+
val factory = UnsafeProjection
137139
val fieldTypes: Array[DataType] = Array(
138140
NullType,
139141
BooleanType,
@@ -254,7 +256,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
254256
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
255257
}
256258

257-
testWithFactory("NaN canonicalization") { factory =>
259+
testBothCodegenAndInterpreted("NaN canonicalization") {
260+
val factory = UnsafeProjection
258261
val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
259262

260263
val row1 = new SpecificInternalRow(fieldTypes)
@@ -269,7 +272,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
269272
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
270273
}
271274

272-
testWithFactory("basic conversion with struct type") { factory =>
275+
testBothCodegenAndInterpreted("basic conversion with struct type") {
276+
val factory = UnsafeProjection
273277
val fieldTypes: Array[DataType] = Array(
274278
new StructType().add("i", IntegerType),
275279
new StructType().add("nest", new StructType().add("l", LongType))
@@ -331,7 +335,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
331335
assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
332336
}
333337

334-
testWithFactory("basic conversion with array type") { factory =>
338+
testBothCodegenAndInterpreted("basic conversion with array type") {
339+
val factory = UnsafeProjection
335340
val fieldTypes: Array[DataType] = Array(
336341
ArrayType(IntegerType),
337342
ArrayType(ArrayType(IntegerType))
@@ -361,7 +366,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
361366
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
362367
}
363368

364-
testWithFactory("basic conversion with map type") { factory =>
369+
testBothCodegenAndInterpreted("basic conversion with map type") {
370+
val factory = UnsafeProjection
365371
val fieldTypes: Array[DataType] = Array(
366372
MapType(IntegerType, IntegerType),
367373
MapType(IntegerType, MapType(IntegerType, IntegerType))
@@ -407,7 +413,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
407413
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
408414
}
409415

410-
testWithFactory("basic conversion with struct and array") { factory =>
416+
testBothCodegenAndInterpreted("basic conversion with struct and array") {
417+
val factory = UnsafeProjection
411418
val fieldTypes: Array[DataType] = Array(
412419
new StructType().add("arr", ArrayType(IntegerType)),
413420
ArrayType(new StructType().add("l", LongType))
@@ -446,7 +453,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
446453
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
447454
}
448455

449-
testWithFactory("basic conversion with struct and map") { factory =>
456+
testBothCodegenAndInterpreted("basic conversion with struct and map") {
457+
val factory = UnsafeProjection
450458
val fieldTypes: Array[DataType] = Array(
451459
new StructType().add("map", MapType(IntegerType, IntegerType)),
452460
MapType(IntegerType, new StructType().add("l", LongType))
@@ -492,7 +500,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
492500
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
493501
}
494502

495-
testWithFactory("basic conversion with array and map") { factory =>
503+
testBothCodegenAndInterpreted("basic conversion with array and map") {
504+
val factory = UnsafeProjection
496505
val fieldTypes: Array[DataType] = Array(
497506
ArrayType(MapType(IntegerType, IntegerType)),
498507
MapType(IntegerType, ArrayType(IntegerType))

0 commit comments

Comments
 (0)