
Commit 931fa28

fix UT failure and add new test
1 parent 8ee56bb commit 931fa28

File tree

2 files changed: +30 -6 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 17 additions & 6 deletions
@@ -404,21 +404,26 @@ abstract class HashExpression[E] extends Expression {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
-      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
+      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
     }
     val hashResultType = CodeGenerator.javaType(dataType)
-    ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq(hashResultType -> result),
+      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
       returnType = hashResultType,
       makeSplitFunction = body =>
         s"""
            |$body
            |return $result;
          """.stripMargin,
       foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |$code
+     """.stripMargin
   }
 
   @tailrec
@@ -778,21 +783,22 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val childResult = ctx.freshName("childResult")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
       val computeFieldHash = nullSafeElementHash(
-        input, index.toString, field.nullable, field.dataType, childResult, ctx)
+        tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
       s"""
          |$childResult = 0;
          |$computeFieldHash
          |$result = (31 * $result) + $childResult;
        """.stripMargin
     }
 
-    s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
+      arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
       returnType = CodeGenerator.JAVA_INT,
       makeSplitFunction = body =>
         s"""
@@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
           |return $result;
         """.stripMargin,
       foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |${CodeGenerator.JAVA_INT} $childResult = 0;
+       |$code
+     """.stripMargin
   }
 }
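In both genHashForStruct implementations, the per-field hash snippets are handed to ctx.splitExpressions, which may split them into separate computeHashForStruct helper methods when a struct has many fields. Before this change, the HashExpression variant did not pass the input row to those helpers at all, and the HiveHash variant passed the caller's input expression directly; if that expression is not a plain variable visible inside the generated helpers, the emitted Java does not compile. The patch binds the input to a fresh local (final InternalRow tmpInput = input;) and threads tmpInput through as an explicit InternalRow argument. The sketch below mimics that split-and-parameterize pattern in plain Scala; it is NOT Spark's CodegenContext API, and the object name SplitSketch, the helper splitIntoFunctions, and the group size of 3 are made up for illustration.

    object SplitSketch {
      // Split generated statements into helper methods, mirroring the shape of the
      // call in the patch: each helper receives the row and the running result as
      // parameters, returns the updated result, and the caller folds the calls
      // back into assignments.
      def splitIntoFunctions(
          statements: Seq[String],
          funcName: String,
          arguments: Seq[(String, String)],
          returnType: String,
          resultVar: String): String = {
        val params = arguments.map { case (tpe, name) => s"$tpe $name" }.mkString(", ")
        val callArgs = arguments.map(_._2).mkString(", ")
        val groups = statements.grouped(3).toSeq // group size is arbitrary here
        val funcs = groups.zipWithIndex.map { case (body, i) =>
          s"""private $returnType ${funcName}_$i($params) {
             |  ${body.mkString("\n  ")}
             |  return $resultVar;
             |}""".stripMargin
        }
        val calls = groups.indices.map(i => s"$resultVar = ${funcName}_$i($callArgs);")
        (funcs ++ calls).mkString("\n")
      }

      def main(args: Array[String]): Unit = {
        // The caller binds the struct being hashed to a fresh local name, as the
        // patch does with `final InternalRow $tmpInput = $input;`, and passes it
        // to every helper. Without that parameter, the per-field statements would
        // reference a variable that exists only in the caller, and the emitted
        // Java would fail to compile once the field count forces a split.
        val tmpInput = "tmpInput_0"
        val fieldHashes = (0 until 10).map(i => s"hash = 31 * hash + $tmpInput.getInt($i);")
        println(splitIntoFunctions(fieldHashes, "computeHashForStruct",
          Seq("InternalRow" -> tmpInput, "int" -> "hash"), "int", "hash"))
      }
    }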

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 13 additions & 0 deletions
@@ -2831,4 +2831,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") {
+    withView("spark_25084") {
+      val count = 1000
+      val df = spark.range(count)
+      val columns = (0 until 400).map{ i => s"id as id$i" }
+      val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
+      df.selectExpr(columns : _*).createTempView("spark_25084")
+      assert(
+        spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count()
+          === count)
+    }
+  }
 }
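The regression test builds a 400-column temp view and runs DISTRIBUTE BY over 100 of those columns, which makes the generated hash code for the clustering expressions wide enough to go through the method-splitting path changed above; per the test's title, such queries previously hit a codegen issue. A rough spark-shell style reproduction of the same scenario, assuming an existing SparkSession named spark (the view name and column counts are illustrative, not required):

    // Rough reproduction sketch of the scenario the test exercises.
    val wide = spark.range(1000).selectExpr((0 until 400).map(i => s"id as id$i"): _*)
    wide.createOrReplaceTempView("wide_view")
    val clusterCols = (0 until 100).map(i => s"id$i").mkString(", ")
    // Before this commit, hashing this many clustering columns could generate
    // code that referenced a row variable out of scope in the split-out helpers.
    spark.sql(s"SELECT * FROM wide_view DISTRIBUTE BY ($clusterCols)").count()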
