Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address comment
  • Loading branch information
wangyum committed May 28, 2020
commit 1bdff9584a09d42ba27f9742bf5a6e8a4c5c74d3
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,11 @@ trait DivModLike extends BinaryArithmetic {

override def nullable: Boolean = true

final override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see the difference now. Previously we can skip evaluating input1 if input2 is 0. Can we change it back and add comment to explain it? sorry for the back and forth!

final override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
evalOperation(input1, input2)
}
evalOperation(input1, input2)
}
}

Expand Down Expand Up @@ -516,24 +510,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])

// Do not check these expressions, because these expressions extend NullIntolerant
// and override the eval function.
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])

val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
.map(name => Utils.classForName(name))
.filterNot(classOf[NonSQLExpression].isAssignableFrom)

Expand All @@ -180,8 +175,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
superClass.getMethod("eval", classOf[InternalRow])
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
if (isEvalOverrode && isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
s"or add ${clazz.getName} in the ignoreSet of this test.")
fail(s"${clazz.getName} overrode the eval method and extended " +
s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. " +
s"You may need to override the nullSafeEval method.")
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
} else {
Expand Down