Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix
  • Loading branch information
wangyum committed May 26, 2020
commit b127c4134ea5c75b10093d53e8874de1ff54e807
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ package org.apache.spark.sql.expressions
import scala.collection.parallel.immutable.ParVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -159,73 +158,37 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
}
}

test("Check whether should extend NullIntolerant") {
// Only check expressions extended from these expressions
val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression],
classOf[SeptenaryExpression]).map(_.getName)
// Do not check these expressions
val whiteList = Seq(
classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod],
classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet],
classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName)

spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => whiteList.exists(_.equals(c))).foreach { className =>
if (needToCheckNullIntolerant(className)) {
val evalExist = checkIfEvalOverrode(className)
val nullIntolerantExist = checkIfNullIntolerantMixedIn(className)
if (evalExist && nullIntolerantExist) {
fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}")
} else if (!evalExist && !nullIntolerantExist) {
fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}")
} else {
assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist))
}
}
}
test("Check whether SQL expressions should extend NullIntolerant") {
// Only check expressions extended from these expressions because these expressions are
// NullIntolerant by default.
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])

def needToCheckNullIntolerant(className: String): Boolean = {
var clazz: Class[_] = Utils.classForName(className)
val isNonSQLExpr =
clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName))
var checkNullIntolerant: Boolean = false
while (!checkNullIntolerant && clazz.getSuperclass != null) {
checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName))
if (!checkNullIntolerant) {
clazz = clazz.getSuperclass
}
}
checkNullIntolerant && !isNonSQLExpr
}
// Do not check these expressions, because these expressions extend NullIntolerant
// and override the eval function.
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])

def checkIfNullIntolerantMixedIn(className: String): Boolean = {
val nullIntolerantName = classOf[NullIntolerant].getName
var clazz: Class[_] = Utils.classForName(className)
var nullIntolerantMixedIn = false
while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) {
nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) ||
clazz.getInterfaces.exists { i =>
Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName))
}
if (!nullIntolerantMixedIn) {
clazz = clazz.getSuperclass
}
}
nullIntolerantMixedIn
}

def checkIfEvalOverrode(className: String): Boolean = {
var clazz: Class[_] = Utils.classForName(className)
var evalOverrode: Boolean = false
while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) {
evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval"))
if (!evalOverrode) {
clazz = clazz.getSuperclass
val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
.map(name => Utils.classForName(name))
.filterNot(classOf[NonSQLExpression].isAssignableFrom)

exprTypesToCheck.foreach { superClass =>
candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>
val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) !=
superClass.getMethod("eval", classOf[InternalRow])
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
if (isEvalOverrode && isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
s"or add ${clazz.getName} in the ignoreSet of this test.")
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
} else {
assert((!isEvalOverrode && isNullIntolerantMixedIn) ||
(isEvalOverrode && !isNullIntolerantMixedIn))
}
}
evalOverrode
}
}
}