diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e3541dc7ee730..03f9da66cab48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1547,14 +1547,27 @@ case class BRound(child: Expression, scale: Expression) } object WidthBucket { - def computeBucketNumber(value: Double, min: Double, max: Double, numBucket: Long): jl.Long = { - if (numBucket <= 0 || numBucket == Long.MaxValue || jl.Double.isNaN(value) || min == max || - jl.Double.isNaN(min) || jl.Double.isInfinite(min) || - jl.Double.isNaN(max) || jl.Double.isInfinite(max)) { - return null + if (isNull(value, min, max, numBucket)) { + null + } else { + computeBucketNumberNotNull(value, min, max, numBucket) } + } + /** This function is called by generated Java code, so it needs to be public. */ + def isNull(value: Double, min: Double, max: Double, numBucket: Long): Boolean = { + numBucket <= 0 || + numBucket == Long.MaxValue || + jl.Double.isNaN(value) || + min == max || + jl.Double.isNaN(min) || jl.Double.isInfinite(min) || + jl.Double.isNaN(max) || jl.Double.isInfinite(max) + } + + /** This function is called by generated Java code, so it needs to be public. */ + def computeBucketNumberNotNull( + value: Double, min: Double, max: Double, numBucket: Long): jl.Long = { val lower = Math.min(min, max) val upper = Math.max(min, max) @@ -1666,9 +1679,14 @@ case class WidthBucket( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (input, min, max, numBucket) => - "org.apache.spark.sql.catalyst.expressions.WidthBucket" + - s".computeBucketNumber($input, $min, $max, $numBucket)") + nullSafeCodeGen(ctx, ev, (input, min, max, numBucket) => { + s"""${ev.isNull} = org.apache.spark.sql.catalyst.expressions.WidthBucket + | .isNull($input, $min, $max, $numBucket); + |if (!${ev.isNull}) { + | ${ev.value} = org.apache.spark.sql.catalyst.expressions.WidthBucket + | .computeBucketNumberNotNull($input, $min, $max, $numBucket); + |}""".stripMargin + }) } override def first: Expression = value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index bd133e75781cc..ea0d619ad4c15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -760,4 +760,32 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WidthBucket(Literal(v), Literal(s), Literal(e), Literal(n)), expected) } } + + test("SPARK-37388: width_bucket") { + val nullDouble = Literal.create(null, DoubleType) + val nullLong = Literal.create(null, LongType) + + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, 5L), 3L) + checkEvaluation(WidthBucket(-2.1, 1.3, 3.4, 3L), 0L) + checkEvaluation(WidthBucket(8.1, 0.0, 5.7, 4L), 5L) + checkEvaluation(WidthBucket(-0.9, 5.2, 0.5, 2L), 3L) + checkEvaluation(WidthBucket(nullDouble, 0.024, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, nullDouble, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, nullDouble, 5L), null) + checkEvaluation(WidthBucket(5.35, nullDouble, nullDouble, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, nullLong), null) + checkEvaluation(WidthBucket(nullDouble, nullDouble, nullDouble, nullLong), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, -5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, Long.MaxValue), null) + checkEvaluation(WidthBucket(Double.NaN, 0.024, 10.06, 5L), null) + checkEvaluation(WidthBucket(Double.NegativeInfinity, 0.024, 10.06, 5L), 0L) + checkEvaluation(WidthBucket(Double.PositiveInfinity, 0.024, 10.06, 5L), 6L) + checkEvaluation(WidthBucket(5.35, 0.024, 0.024, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.NaN, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.NegativeInfinity, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.PositiveInfinity, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.NaN, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.NegativeInfinity, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.PositiveInfinity, 5L), null) + } }