Skip to content

Commit f490471

Browse files
SEONGJAEGONG (pasar6987)
authored and committed
[SPARK-52659][SQL] Misleading modulo error message in ansi mode
### What changes were proposed in this pull request? 1. **Updated error condition**: Changed `MOD_BY_ZERO` to `REMAINDER_BY_ZERO` in `error-conditions.json` with an appropriate error message "Remainder by zero" instead of "Mod by zero" 2. **Updated test cases**: Modified `ArithmeticExpressionSuite.scala` to separate division and remainder operation tests, ensuring that remainder operations (`Remainder` and `Pmod`) expect "Remainder by zero" error message instead of "Division by zero" The mod function now correctly throws `REMAINDER_BY_ZERO` error with message "Remainder by zero. Use `try_mod` to tolerate divisor being 0 and return NULL instead." instead of the misleading `DIVIDE_BY_ZERO` error. ### Why are the changes needed? In ANSI mode, when executing `spark.sql("select mod(10,0)")`, the system incorrectly throws a `DIVIDE_BY_ZERO` error with message "Division by zero". ### Does this PR introduce _any_ user-facing change? Yes. This PR changes the user-facing error message for modulo operations in ANSI mode. **Before:** org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use try_divide to tolerate divisor being 0 and return NULL instead. **After:** org.apache.spark.SparkArithmeticException: [REMAINDER_BY_ZERO] Remainder by zero. Use try_mod to tolerate divisor being 0 and return NULL instead. ### How was this patch tested? Updated existing unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. This is my first pull request. I'm writing this pull request after reviewing the documentation as much as possible, but if there's anything wrong, please let me know. Closes #51378 from pasar6987/Misleading-modulo-error-message-in-ANSI-mode. Lead-authored-by: 공성재 <[email protected]> Co-authored-by: 공성재 <[email protected]> Co-authored-by: pasar6987 <[email protected]> Signed-off-by: Gengliang Wang <[email protected]>
1 parent 3990b0f commit f490471

File tree

7 files changed

+107
-23
lines changed

7 files changed

+107
-23
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4920,6 +4920,12 @@
49204920
],
49214921
"sqlState" : "42601"
49224922
},
4923+
"REMAINDER_BY_ZERO" : {
4924+
"message" : [
4925+
"Remainder by zero. Use `try_mod` to tolerate divisor being 0 and return NULL instead. If necessary set <config> to \"false\" to bypass this error."
4926+
],
4927+
"sqlState" : "22012"
4928+
},
49234929
"RENAME_SRC_PATH_NOT_FOUND" : {
49244930
"message" : [
49254931
"Failed to rename as <sourcePath> was not found."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,13 @@ trait DivModLike extends BinaryArithmetic {
654654
} else {
655655
if (isZero(input2)) {
656656
// when we reach here, failOnError must be true.
657-
throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
657+
val context = getContextOrNull()
658+
val ex = this match {
659+
case _: Remainder => QueryExecutionErrors.remainderByZeroError(context)
660+
case _: Pmod => QueryExecutionErrors.remainderByZeroError(context)
661+
case _ => QueryExecutionErrors.divideByZeroError(context)
662+
}
663+
throw ex
658664
}
659665
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
660666
throw QueryExecutionErrors.overflowInIntegralDivideError(getContextOrNull())
@@ -669,6 +675,15 @@ trait DivModLike extends BinaryArithmetic {
669675
/**
670676
* Special case handling due to division/remainder by 0 => null or ArithmeticException.
671677
*/
678+
protected def divideByZeroErrorCode(ctx: CodegenContext): String = {
679+
val errorContextCode = getContextOrNullCode(ctx, failOnError)
680+
this match {
681+
case _: Remainder => s"QueryExecutionErrors.remainderByZeroError($errorContextCode)"
682+
case _: Pmod => s"QueryExecutionErrors.remainderByZeroError($errorContextCode)"
683+
case _ => s"QueryExecutionErrors.divideByZeroError($errorContextCode)"
684+
}
685+
}
686+
672687
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
673688
val eval1 = left.genCode(ctx)
674689
val eval2 = right.genCode(ctx)
@@ -706,7 +721,7 @@ trait DivModLike extends BinaryArithmetic {
706721
// evaluate right first as we have a chance to skip left if right is 0
707722
if (!left.nullable && !right.nullable) {
708723
val divByZero = if (failOnError) {
709-
s"throw QueryExecutionErrors.divideByZeroError($errorContextCode);"
724+
s"throw ${divideByZeroErrorCode(ctx)};"
710725
} else {
711726
s"${ev.isNull} = true;"
712727
}
@@ -724,7 +739,7 @@ trait DivModLike extends BinaryArithmetic {
724739
} else {
725740
val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
726741
val failOnErrorBranch = if (failOnError) {
727-
s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContextCode);"
742+
s"if ($isZero) throw ${divideByZeroErrorCode(ctx)};"
728743
} else {
729744
""
730745
}
@@ -1047,7 +1062,7 @@ case class Pmod(
10471062
} else {
10481063
if (isZero(input2)) {
10491064
// when we reach here, failOnError must bet true.
1050-
throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
1065+
throw QueryExecutionErrors.remainderByZeroError(getContextOrNull())
10511066
}
10521067
pmodFunc(input1, input2)
10531068
}
@@ -1104,7 +1119,7 @@ case class Pmod(
11041119
// evaluate right first as we have a chance to skip left if right is 0
11051120
if (!left.nullable && !right.nullable) {
11061121
val divByZero = if (failOnError) {
1107-
s"throw QueryExecutionErrors.divideByZeroError($errorContext);"
1122+
s"throw QueryExecutionErrors.remainderByZeroError($errorContext);"
11081123
} else {
11091124
s"${ev.isNull} = true;"
11101125
}
@@ -1121,7 +1136,7 @@ case class Pmod(
11211136
} else {
11221137
val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
11231138
val failOnErrorBranch = if (failOnError) {
1124-
s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContext);"
1139+
s"if ($isZero) throw QueryExecutionErrors.remainderByZeroError($errorContext);"
11251140
} else {
11261141
""
11271142
}

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
205205
summary = getSummary(context))
206206
}
207207

208+
def remainderByZeroError(context: QueryContext): ArithmeticException = {
209+
new SparkArithmeticException(
210+
errorClass = "REMAINDER_BY_ZERO",
211+
messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
212+
context = Array(context),
213+
summary = getSummary(context))
214+
}
215+
208216
def intervalDividedByZeroError(context: QueryContext): ArithmeticException = {
209217
new SparkArithmeticException(
210218
errorClass = "INTERVAL_DIVIDED_BY_ZERO",

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
463463
}
464464
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
465465
checkExceptionInExpression[ArithmeticException](
466-
Remainder(left, Literal(convert(0))), "Division by zero")
466+
Remainder(left, Literal(convert(0))), "Remainder by zero")
467467
}
468468
}
469469
checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort)
@@ -567,7 +567,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
567567
}
568568
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
569569
checkExceptionInExpression[ArithmeticException](
570-
Pmod(left, Literal(convert(0))), "Division by zero")
570+
Pmod(left, Literal(convert(0))), "Remainder by zero")
571571
}
572572
}
573573
checkEvaluation(Pmod(Literal(-7), Literal(3)), 2)
@@ -873,12 +873,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
873873

874874
test("SPARK-33008: division by zero on divide-like operations returns incorrect result") {
875875
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
876-
val operators: Seq[((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)] =
876+
// Test division operations
877+
val divideOperators: Seq[
878+
((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)
879+
] =
877880
Seq((Divide(_, _), testDecimalAndDoubleType),
878-
(IntegralDivide(_, _), testDecimalAndLongType),
879-
(Remainder(_, _), testNumericDataTypes),
880-
(Pmod(_, _), testNumericDataTypes))
881-
operators.foreach { case (operator, testTypesFn) =>
881+
(IntegralDivide(_, _), testDecimalAndLongType))
882+
divideOperators.foreach { case (operator, testTypesFn) =>
882883
testTypesFn { convert =>
883884
val one = Literal(convert(1))
884885
val zero = Literal(convert(0))
@@ -887,6 +888,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
887888
checkExceptionInExpression[ArithmeticException](operator(one, zero), "Division by zero")
888889
}
889890
}
891+
892+
// Test remainder operations
893+
val remainderOperators: Seq[
894+
((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)
895+
] =
896+
Seq((Remainder(_, _), testNumericDataTypes),
897+
(Pmod(_, _), testNumericDataTypes))
898+
remainderOperators.foreach { case (operator, testTypesFn) =>
899+
testTypesFn { convert =>
900+
val one = Literal(convert(1))
901+
val zero = Literal(convert(0))
902+
checkEvaluation(operator(Literal.create(null, one.dataType), zero), null)
903+
checkEvaluation(operator(one, Literal.create(null, zero.dataType)), null)
904+
checkExceptionInExpression[ArithmeticException](operator(one, zero), "Remainder by zero")
905+
}
906+
}
890907
}
891908
}
892909

@@ -931,12 +948,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
931948

932949
test("SPARK-34920: error class") {
933950
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
934-
val operators: Seq[((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)] =
951+
// Test division operations
952+
val divideOperators: Seq[
953+
((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)
954+
] =
935955
Seq((Divide(_, _), testDecimalAndDoubleType),
936-
(IntegralDivide(_, _), testDecimalAndLongType),
937-
(Remainder(_, _), testNumericDataTypes),
938-
(Pmod(_, _), testNumericDataTypes))
939-
operators.foreach { case (operator, testTypesFn) =>
956+
(IntegralDivide(_, _), testDecimalAndLongType))
957+
divideOperators.foreach { case (operator, testTypesFn) =>
940958
testTypesFn { convert =>
941959
val one = Literal(convert(1))
942960
val zero = Literal(convert(0))
@@ -946,6 +964,23 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
946964
"Division by zero")
947965
}
948966
}
967+
968+
// Test remainder operations
969+
val remainderOperators: Seq[
970+
((Expression, Expression) => Expression, ((Int => Any) => Unit) => Unit)
971+
] =
972+
Seq((Remainder(_, _), testNumericDataTypes),
973+
(Pmod(_, _), testNumericDataTypes))
974+
remainderOperators.foreach { case (operator, testTypesFn) =>
975+
testTypesFn { convert =>
976+
val one = Literal(convert(1))
977+
val zero = Literal(convert(0))
978+
checkEvaluation(operator(Literal.create(null, one.dataType), zero), null)
979+
checkEvaluation(operator(one, Literal.create(null, zero.dataType)), null)
980+
checkExceptionInExpression[SparkArithmeticException](operator(one, zero),
981+
"Remainder by zero")
982+
}
983+
}
949984
}
950985
}
951986

sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct<>
3636
-- !query output
3737
org.apache.spark.SparkArithmeticException
3838
{
39-
"errorClass" : "DIVIDE_BY_ZERO",
39+
"errorClass" : "REMAINDER_BY_ZERO",
4040
"sqlState" : "22012",
4141
"messageParameters" : {
4242
"config" : "\"spark.sql.ansi.enabled\""
@@ -58,7 +58,7 @@ struct<>
5858
-- !query output
5959
org.apache.spark.SparkArithmeticException
6060
{
61-
"errorClass" : "DIVIDE_BY_ZERO",
61+
"errorClass" : "REMAINDER_BY_ZERO",
6262
"sqlState" : "22012",
6363
"messageParameters" : {
6464
"config" : "\"spark.sql.ansi.enabled\""

sql/core/src/test/resources/sql-tests/results/operators.sql.out

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ struct<>
496496
-- !query output
497497
org.apache.spark.SparkArithmeticException
498498
{
499-
"errorClass" : "DIVIDE_BY_ZERO",
499+
"errorClass" : "REMAINDER_BY_ZERO",
500500
"sqlState" : "22012",
501501
"messageParameters" : {
502502
"config" : "\"spark.sql.ansi.enabled\""
@@ -566,7 +566,7 @@ struct<>
566566
-- !query output
567567
org.apache.spark.SparkArithmeticException
568568
{
569-
"errorClass" : "DIVIDE_BY_ZERO",
569+
"errorClass" : "REMAINDER_BY_ZERO",
570570
"sqlState" : "22012",
571571
"messageParameters" : {
572572
"config" : "\"spark.sql.ansi.enabled\""
@@ -588,7 +588,7 @@ struct<>
588588
-- !query output
589589
org.apache.spark.SparkArithmeticException
590590
{
591-
"errorClass" : "DIVIDE_BY_ZERO",
591+
"errorClass" : "REMAINDER_BY_ZERO",
592592
"sqlState" : "22012",
593593
"messageParameters" : {
594594
"config" : "\"spark.sql.ansi.enabled\""

sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,26 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
8585
callSitePattern = getCurrentClassCallSitePattern))
8686
}
8787

88+
test("REMAINDER_BY_ZERO: can't take modulo of an integer by zero") {
89+
checkError(
90+
exception = intercept[SparkArithmeticException] {
91+
sql("select 6 % 0").collect()
92+
},
93+
condition = "REMAINDER_BY_ZERO",
94+
sqlState = "22012",
95+
parameters = Map("config" -> ansiConf),
96+
context = ExpectedContext(fragment = "6 % 0", start = 7, stop = 11))
97+
98+
checkError(
99+
exception = intercept[SparkArithmeticException] {
100+
sql("select pmod(6, 0)").collect()
101+
},
102+
condition = "REMAINDER_BY_ZERO",
103+
sqlState = "22012",
104+
parameters = Map("config" -> ansiConf),
105+
context = ExpectedContext(fragment = "pmod(6, 0)", start = 7, stop = 16))
106+
}
107+
88108
test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") {
89109
checkError(
90110
exception = intercept[SparkArithmeticException] {

0 commit comments

Comments
 (0)