From 1658f5bf127cd58a68ae9d16250cc11aaaf64ed8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 20 Apr 2021 22:43:08 +0800 Subject: [PATCH 1/5] throw error on corner case --- .../sql/catalyst/expressions/arithmetic.scala | 41 ++++++++++++++++--- .../sql/errors/QueryExecutionErrors.scala | 4 ++ .../ArithmeticExpressionSuite.scala | 7 ++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 28851918429a..aaef7d85fa36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -404,6 +404,9 @@ trait DivModLike extends BinaryArithmetic { protected def decimalToDataTypeCodeGen(decimalResult: String): String = decimalResult + // When it is an integral divide, we need to check whether overflow happens in ANSI mode. + protected def isIntegralDivide: Boolean = false + override def nullable: Boolean = true private lazy val isZero: Any => Boolean = right.dataType match { @@ -450,6 +453,14 @@ trait DivModLike extends BinaryArithmetic { } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } + val checkIntegralDivideOverflow = if (failOnError && isIntegralDivide) { + s""" + |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) + | throw QueryExecutionErrors.overflowInIntegralDivideError(); + |""".stripMargin + } else { + "" + } // evaluate right first as we have a chance to skip left if right is 0 if (!left.nullable && !right.nullable) { val divByZero = if (failOnError) { @@ -465,6 +476,7 @@ trait DivModLike extends BinaryArithmetic { $divByZero } else { ${eval1.code} + $checkIntegralDivideOverflow ${ev.value} = $operation; }""") } else { @@ -486,6 +498,7 @@ trait DivModLike extends BinaryArithmetic { ${ev.isNull} = true; } else { $failOnErrorBranch + $checkIntegralDivideOverflow ${ev.value} = $operation; } }""") @@ -546,6 +559,8 @@ case class IntegralDivide( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + override def isIntegralDivide: Boolean = true + override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType) override def dataType: DataType = LongType @@ -562,13 +577,27 @@ case class IntegralDivide( case d: DecimalType => d.asIntegral.asInstanceOf[Integral[Any]] } - (x, y) => { - val res = integral.quot(x, y) - if (res == null) { - null - } else { - integral.asInstanceOf[Integral[Any]].toLong(res) + val _div = + (x: Any, y: Any) => { + val res = integral.quot(x, y) + if (res == null) { + null + } else { + integral.asInstanceOf[Integral[Any]].toLong(res) + } } + + dataType match { + case LongType if failOnError => + (x: Any, y: Any) => { + if (x == Long.MinValue && y == -1) { + throw QueryExecutionErrors.overflowInIntegralDivideError() + } + _div(x, y) + } + + case _ => + _div } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index c5a608e38da5..75f56958788c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -146,6 +146,10 @@ object QueryExecutionErrors { new ArithmeticException("Overflow in sum of decimals.") } + def overflowInIntegralDivideError(): ArithmeticException = { + new ArithmeticException("Overflow in integral divide.") + } + def mapSizeExceedArraySizeWhenZipMapError(size: Int): RuntimeException = { new RuntimeException(s"Unsuccessful try to zip maps with $size " + "unique keys due to exceeding the array size limit " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6ccd29921053..5d07912c6896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -207,6 +207,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("IntegralDivide: throw exception on overflow under ANSI mode") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + IntegralDivide(Literal(Long.MinValue), Literal(-1L)), "Overflow in integral divide.") + } + } + test("% (Remainder)") { testNumericDataTypes { convert => val left = Literal(convert(1)) From d91fa4e0a77210534cbf7d29ac206e16c83ed0e5 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 20 Apr 2021 23:58:50 +0800 Subject: [PATCH 2/5] address comment --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index aaef7d85fa36..ca92ac39451d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -587,7 +587,7 @@ case class IntegralDivide( } } - dataType match { + left.dataType match { case LongType if failOnError => (x: Any, y: Any) => { if (x == Long.MinValue && y == -1) { From 9ade94024d1386b0e850aa2d098fd0c48e7ae83a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 21 Apr 2021 00:10:37 +0800 Subject: [PATCH 3/5] fix --- .../sql/catalyst/expressions/arithmetic.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ca92ac39451d..ade6e4f29a0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -453,14 +453,16 @@ trait DivModLike extends BinaryArithmetic { } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - val checkIntegralDivideOverflow = if (failOnError && isIntegralDivide) { - s""" - |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) - | throw QueryExecutionErrors.overflowInIntegralDivideError(); - |""".stripMargin - } else { - "" + val checkIntegralDivideOverflow = left.dataType match { + case LongType if failOnError && isIntegralDivide => + s""" + |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) + | throw QueryExecutionErrors.overflowInIntegralDivideError(); + |""".stripMargin + + case _ => "" } + // evaluate right first as we have a chance to skip left if right is 0 if (!left.nullable && !right.nullable) { val divByZero = if (failOnError) { From a6ad9bfdf5059f88b17d446f2dbfb31a6b065225 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 21 Apr 2021 00:30:23 +0800 Subject: [PATCH 4/5] refactor --- .../sql/catalyst/expressions/arithmetic.scala | 53 ++++++++----------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ade6e4f29a0a..e95231db4e96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -404,8 +404,8 @@ trait DivModLike extends BinaryArithmetic { protected def decimalToDataTypeCodeGen(decimalResult: String): String = decimalResult - // When it is an integral divide, we need to check whether overflow happens in ANSI mode. - protected def isIntegralDivide: Boolean = false + // Whether we should check overflow or not in ANSI mode. + protected def checkDivideOverflow: Boolean = false override def nullable: Boolean = true @@ -428,6 +428,9 @@ trait DivModLike extends BinaryArithmetic { // when we reach here, failOnError must bet true. throw QueryExecutionErrors.divideByZeroError } + if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) { + throw QueryExecutionErrors.overflowInIntegralDivideError() + } evalOperation(input1, input2) } } @@ -453,14 +456,13 @@ trait DivModLike extends BinaryArithmetic { } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - val checkIntegralDivideOverflow = left.dataType match { - case LongType if failOnError && isIntegralDivide => - s""" - |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) - | throw QueryExecutionErrors.overflowInIntegralDivideError(); - |""".stripMargin - - case _ => "" + val checkIntegralDivideOverflow = if (checkDivideOverflow) { + s""" + |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) + | throw QueryExecutionErrors.overflowInIntegralDivideError(); + |""".stripMargin + } else { + "" } // evaluate right first as we have a chance to skip left if right is 0 @@ -561,7 +563,10 @@ case class IntegralDivide( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def isIntegralDivide: Boolean = true + override def checkDivideOverflow: Boolean = left.dataType match { + case LongType if failOnError => true + case _ => false + } override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType) @@ -579,27 +584,13 @@ case class IntegralDivide( case d: DecimalType => d.asIntegral.asInstanceOf[Integral[Any]] } - val _div = - (x: Any, y: Any) => { - val res = integral.quot(x, y) - if (res == null) { - null - } else { - integral.asInstanceOf[Integral[Any]].toLong(res) - } + (x: Any, y: Any) => { + val res = integral.quot(x, y) + if (res == null) { + null + } else { + integral.asInstanceOf[Integral[Any]].toLong(res) } - - left.dataType match { - case LongType if failOnError => - (x: Any, y: Any) => { - if (x == Long.MinValue && y == -1) { - throw QueryExecutionErrors.overflowInIntegralDivideError() - } - _div(x, y) - } - - case _ => - _div } } From 64f1431195e666227094a456f9c4d20fff181004 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 21 Apr 2021 00:32:12 +0800 Subject: [PATCH 5/5] revise --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e95231db4e96..954a4b9fc138 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -584,7 +584,7 @@ case class IntegralDivide( case d: DecimalType => d.asIntegral.asInstanceOf[Integral[Any]] } - (x: Any, y: Any) => { + (x, y) => { val res = integral.quot(x, y) if (res == null) { null