Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
divide 0 return null as other types
  • Loading branch information
yaooqinn committed Dec 24, 2019
commit 67645d4d3415735fcd464338ed6d0dcd52c32088
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
case CalendarIntervalType =>
DivideInterval(sum.cast(resultType), count.cast(DoubleType), false)
DivideInterval(sum.cast(resultType), count.cast(DoubleType))
case _ =>
sum.cast(resultType) / count.cast(resultType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,27 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression)

protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled

protected def operation(interval: CalendarInterval, num: Double): CalendarInterval

protected val operationName: String

override def left: Expression = interval
override def right: Expression = num

override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DoubleType)
override def dataType: DataType = CalendarIntervalType

override def nullable: Boolean = true
}

case class MultiplyInterval(interval: Expression, num: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this also new in Spark 3.0? If it is I think it's OK to always follow the ansi behavior regardless the ansi flag.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it is

extends IntervalNumOperation(interval, num) {

override def prettyName: String = "multiply_interval"

override def nullSafeEval(interval: Any, num: Any): Any = {
try {
operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
if (checkOverflow) {
multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
} else {
safeMultiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
}
} catch {
case _: ArithmeticException if !checkOverflow => null
}
Expand All @@ -141,6 +147,7 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (interval, num) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
val operationName = if (checkOverflow) "multiply" else "safeMultiply"
s"""
try {
${ev.value} = $iu.$operationName($interval, $num);
Expand All @@ -154,31 +161,47 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression)
"""
})
}

override def prettyName: String = operationName.stripPrefix("safe").toLowerCase() + "_interval"
}

case class MultiplyInterval(interval: Expression, num: Expression)
case class DivideInterval(interval: Expression, num: Expression)
extends IntervalNumOperation(interval, num) {

override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = {
if (checkOverflow) multiply(interval, num) else safeMultiply(interval, num)
}
override def prettyName: String = "divide_interval"

override protected val operationName: String = if (checkOverflow) "multiply" else "safeMultiply"
}

case class DivideInterval(
interval: Expression,
num: Expression,
override val checkOverflow: Boolean = SQLConf.get.ansiEnabled)
extends IntervalNumOperation(interval, num) {

override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = {
if (checkOverflow) divide(interval, num) else safeDivide(interval, num)
override def nullSafeEval(interval: Any, num: Any): Any = {
try {
if (num == 0) return null
if (checkOverflow) {
divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
} else {
safeDivide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
}
} catch {
case _: ArithmeticException if !checkOverflow => null
}
}

override protected val operationName: String = if (checkOverflow) "divide" else "safeDivide"
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (interval, num) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
val operationName = if (checkOverflow) "divide" else "safeDivide"
s"""
try {
if ($num == 0) {
${ev.isNull} = true;
} else {
${ev.value} = $iu.$operationName($interval, $num);
}
} catch (ArithmeticException e) {
if ($checkOverflow) {
throw e;
} else {
${ev.isNull} = true;
}
}
"""
})
}
}

// scalastyle:off line.size.limit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,9 @@ struct<divide_interval(subtracttimestamps(TIMESTAMP '2019-10-15 00:00:00', TIMES
-- !query 25
select interval '2 seconds' / 0
-- !query 25 schema
struct<>
struct<divide_interval(INTERVAL '2 seconds', CAST(0 AS DOUBLE)):interval>
-- !query 25 output
java.lang.ArithmeticException
divide by zero
NULL


-- !query 26
Expand Down