Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor
  • Loading branch information
gengliangwang committed Apr 20, 2021
commit a6ad9bfdf5059f88b17d446f2dbfb31a6b065225
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ trait DivModLike extends BinaryArithmetic {

protected def decimalToDataTypeCodeGen(decimalResult: String): String = decimalResult

// When it is an integral divide, we need to check whether overflow happens in ANSI mode.
protected def isIntegralDivide: Boolean = false
// Whether we should check overflow or not in ANSI mode.
protected def checkDivideOverflow: Boolean = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this simply left.dataType == LongType? For Divide, the child is either double or decimal, so won't hit this branch.

Under the hood, what we really want to handle is Long.MinValue / -1, so check left.dataType == LongType looks reasonable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, Remainder's left can be Long as well


override def nullable: Boolean = true

Expand All @@ -428,6 +428,9 @@ trait DivModLike extends BinaryArithmetic {
// when we reach here, failOnError must bet true.
throw QueryExecutionErrors.divideByZeroError
}
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
throw QueryExecutionErrors.overflowInIntegralDivideError()
}
evalOperation(input1, input2)
}
}
Expand All @@ -453,14 +456,13 @@ trait DivModLike extends BinaryArithmetic {
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
val checkIntegralDivideOverflow = left.dataType match {
case LongType if failOnError && isIntegralDivide =>
s"""
|if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1)
| throw QueryExecutionErrors.overflowInIntegralDivideError();
|""".stripMargin

case _ => ""
val checkIntegralDivideOverflow = if (checkDivideOverflow) {
s"""
|if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1)
| throw QueryExecutionErrors.overflowInIntegralDivideError();
|""".stripMargin
} else {
""
}

// evaluate right first as we have a chance to skip left if right is 0
Expand Down Expand Up @@ -561,7 +563,10 @@ case class IntegralDivide(

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

override def isIntegralDivide: Boolean = true
override def checkDivideOverflow: Boolean = left.dataType match {
case LongType if failOnError => true
case _ => false
}

override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType)

Expand All @@ -579,27 +584,13 @@ case class IntegralDivide(
case d: DecimalType =>
d.asIntegral.asInstanceOf[Integral[Any]]
}
val _div =
(x: Any, y: Any) => {
val res = integral.quot(x, y)
if (res == null) {
null
} else {
integral.asInstanceOf[Integral[Any]].toLong(res)
}
(x: Any, y: Any) => {
val res = integral.quot(x, y)
if (res == null) {
null
} else {
integral.asInstanceOf[Integral[Any]].toLong(res)
}

left.dataType match {
case LongType if failOnError =>
(x: Any, y: Any) => {
if (x == Long.MinValue && y == -1) {
throw QueryExecutionErrors.overflowInIntegralDivideError()
}
_div(x, y)
}

case _ =>
_div
}
}

Expand Down