Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{
CodegenFallback, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -145,46 +143,42 @@ case class ArrayContains(left: Expression, right: Expression)
}
}

override def nullable: Boolean = false
override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we need to swap left and right?


override def eval(input: InternalRow): Boolean = {
val arr = left.eval(input)
if (arr == null) {
false
} else {
val value = right.eval(input)
if (value == null) {
false
} else {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == value) return true
)
false
override def nullSafeEval(arr: Any, value: Any): Any = {
var hasNull = false
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
} else if (v == value) {
return true
}
)
if (hasNull) {
null
} else {
false
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val arrGen = left.gen(ctx)
val elementGen = right.gen(ctx)
val i = ctx.freshName("i")
val getValue = ctx.getValue(arrGen.primitive, right.dataType, i)
s"""
${arrGen.code}
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = false;
if (!${arrGen.isNull}) {
${elementGen.code}
if (!${elementGen.isNull}) {
for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) {
if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) {
${ev.primitive} = true;
break;
}
}
nullSafeCodeGen(ctx, ev, (arr, value) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to call {defineCodeGen}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nullSafeCodeGen is similar to defineCodeGen

val i = ctx.freshName("i")
val getValue = ctx.getValue(arr, right.dataType, i)
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
${ev.isNull} = true;
} else if (${ctx.genEqual(right.dataType, value, getValue)}) {
${ev.isNull} = false;
${ev.primitive} = true;
break;
}
}
"""
})
}

override def prettyName: String = "array_contains"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -103,6 +101,8 @@ case class Not(child: Expression)
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {

require(list != null, "list should not be null")

override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)

override def checkInputDataTypes(): TypeCheckResult = {
Expand All @@ -116,12 +116,31 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate

override def children: Seq[Expression] = value +: list

override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
list.exists(e => e.eval(input) == evaluatedValue)
if (evaluatedValue == null) {
null
} else {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == evaluatedValue) {
return true
} else if (v == null) {
hasNull = true
}
}
if (hasNull) {
null
} else {
false
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand All @@ -131,16 +150,21 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
s"""
if (!${ev.primitive}) {
${x.code}
if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
if (${x.isNull}) {
${ev.isNull} = true;
} else if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
${ev.isNull} = false;
${ev.primitive} = true;
}
}
""").mkString("\n")
s"""
${valueGen.code}
boolean ${ev.primitive} = false;
boolean ${ev.isNull} = false;
$listCode
boolean ${ev.isNull} = ${valueGen.isNull};
if (!${ev.isNull}) {
$listCode
}
"""
}
}
Expand All @@ -151,11 +175,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
*/
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {

override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
require(hset != null, "hset could not be null")

override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"

override def eval(input: InternalRow): Any = {
hset.contains(child.eval(input))
@transient private[this] lazy val hasNull: Boolean = hset.contains(null)

override def nullable: Boolean = child.nullable || hasNull

protected override def nullSafeEval(value: Any): Any = {
if (hset.contains(value)) {
true
} else if (hasNull) {
null
} else {
false
}
}

def getHSet(): Set[Any] = hset
Expand All @@ -166,12 +201,20 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
val childGen = child.gen(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
s"""
${childGen.code}
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.primitive} = false;
if (!${ev.isNull}) {
${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
if (!${ev.primitive} && $hasNullTerm) {
${ev.isNull} = true;
}
}
"""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,6 @@ object ConstantFolding extends Rule[LogicalPlan] {

// Fold expressions that are foldable.
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)

// Fold "literal in (item1, item2, ..., literal, ...)" into true directly.
case In(Literal(v, _), list) if list.exists {
case Literal(candidate, _) if candidate == v => true
case _ => false
} => Literal.create(true, BooleanType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
val a3 = Literal.create(null, ArrayType(StringType))

checkEvaluation(ArrayContains(a0, Literal(1)), true)
checkEvaluation(ArrayContains(a0, Literal(0)), false)
checkEvaluation(ArrayContains(a0, Literal(null)), false)
checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)

checkEvaluation(ArrayContains(a1, Literal("")), true)
checkEvaluation(ArrayContains(a1, Literal(null)), false)
checkEvaluation(ArrayContains(a1, Literal("a")), null)
checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)

checkEvaluation(ArrayContains(a2, Literal(null)), false)
checkEvaluation(ArrayContains(a2, Literal(1L)), null)
checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)

checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -119,21 +118,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)

test("IN") {
checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null)
checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))),
null)
checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
true)

checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
val ns = Literal.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
Expand All @@ -142,9 +151,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
val input = inputData.map(Literal(_))
checkEvaluation(In(input(0), input.slice(1, 10)),
inputData.slice(1, 10).contains(inputData(0)))
val input = inputData.map(Literal.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
true
} else if (inputData.slice(1, 10).contains(null)) {
null
} else {
false
}
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
}
}

Expand All @@ -158,15 +175,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS), true)
checkEvaluation(InSet(nl, nS), true)
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
checkEvaluation(InSet(three, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
Expand All @@ -176,8 +193,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
val input = inputData.map(Literal(_))
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
inputData.slice(1, 10).contains(inputData(0)))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
true
} else if (inputData.slice(1, 10).contains(null)) {
null
} else {
false
}
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest {
}

test("Constant folding test: Fold In(v, list) into true or false") {
var originalQuery =
val originalQuery =
testRelation
.select('a)
.where(In(Literal(1), Seq(Literal(1), Literal(2))))

var optimized = Optimize.execute(originalQuery.analyze)

var correctAnswer =
testRelation
.select('a)
.where(Literal(true))
.analyze

comparePlans(optimized, correctAnswer)

originalQuery =
testRelation
.select('a)
.where(In(Literal(1), Seq(Literal(1), 'a.attr)))

optimized = Optimize.execute(originalQuery.analyze)
val optimized = Optimize.execute(originalQuery.analyze)

correctAnswer =
val correctAnswer =
testRelation
.select('a)
.where(Literal(true))
Expand Down
Loading