Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Reuse exactMathMethod()
  • Loading branch information
MaxGekk committed Mar 10, 2021
commit 0a202f3490bbb84e19dc041d0c73ef248304498a
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,6 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

/** Name of the function for this expression on [[DayTimeIntervalType]] and
* [[YearMonthIntervalType]] types. */
def intervalMethod: String =
sys.error("BinaryArithmetics must override either intervalMethod or genCode")

// Name of the function for the exact version of this expression in [[Math]].
// If the option "spark.sql.ansi.enabled" is enabled and there is corresponding
// function in [[Math]], the exact function will be called instead of evaluation with [[symbol]].
Expand All @@ -198,8 +193,11 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
case DayTimeIntervalType | YearMonthIntervalType =>
assert(exactMathMethod.isDefined,
s"The expression '$nodeName' must override the exactMathMethod() method " +
s"if it is supposed to operate over interval types.")
val mathClass = classOf[Math].getName
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${intervalMethod}($eval1, $eval2)")
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
Expand Down Expand Up @@ -272,7 +270,6 @@ case class Add(
override def decimalMethod: String = "$plus"

override def calendarIntervalMethod: String = if (failOnError) "addExact" else "add"
override def intervalMethod: String = "addExact"

private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

Expand Down Expand Up @@ -316,7 +313,6 @@ case class Subtract(
override def decimalMethod: String = "$minus"

override def calendarIntervalMethod: String = if (failOnError) "subtractExact" else "subtract"
override def intervalMethod: String = "subtractExact"

private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

Expand Down