Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression {

/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
* null, and if not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
Expand All @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression {

/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
* null, and if not null, use `f` to generate the expression.
*
* @param f function that accepts the non-null evaluation result name of child and returns Java
* code to compute the output.
Expand All @@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: String => String): String = {
val eval = child.gen(ctx)
val childGen = child.gen(ctx)
val resultCode = f(childGen.value)

if (nullable) {
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${eval.isNull}) {
${f(eval.value)}
}
$nullSafeEval
"""
} else {
ev.isNull = "false"
eval.code + s"""
s"""
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${f(eval.value)}
$resultCode
"""
}
}
Expand Down Expand Up @@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.value, eval2.value)
val leftGen = left.gen(ctx)
val rightGen = right.gen(ctx)
val resultCode = f(leftGen.value, rightGen.value)

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}

s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
$resultCode
} else {
${ev.isNull} = true;
}
}
$nullSafeEval
"""

} else {
ev.isNull = "false"
s"""
${eval1.code}
${eval2.code}
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
Expand Down Expand Up @@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression {

/**
* Default behavior of evaluation according to the default nullability of TernaryExpression.
* If subclass of BinaryExpression override nullable, probably should also override this.
* If subclass of TernaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
Expand All @@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression {
sys.error(s"BinaryExpressions must override either eval or nullSafeEval")

/**
* Short hand for generating binary evaluation code.
* Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
* @param f accepts three variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
Expand All @@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression {
}

/**
* Short hand for generating binary evaluation code.
* Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f function that accepts the 2 non-null evaluation result names of children
* @param f function that accepts the 3 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): String = {
val evals = children.map(_.gen(ctx))
val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
val leftGen = children(0).gen(ctx)
val midGen = children(1).gen(ctx)
val rightGen = children(2).gen(ctx)
val resultCode = f(leftGen.value, midGen.value, rightGen.value)

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}
}

s"""
${evals(0).code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${evals(0).isNull}) {
${evals(1).code}
if (!${evals(1).isNull}) {
${evals(2).code}
if (!${evals(2).isNull}) {
${ev.isNull} = false; // resultCode could change nullability
$resultCode
}
}
}
$nullSafeEval
"""
} else {
ev.isNull = "false"
s"""
${evals(0).code}
${evals(1).code}
${evals(2).code}
${leftGen.code}
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,17 +355,37 @@ class CodegenContext {
}

/**
* Generates code for greater of two expressions.
*
* @param dataType data type of the expressions
* @param c1 name of the variable of expression 1's output
* @param c2 name of the variable of expression 2's output
*/
* Generates code for greater of two expressions.
*
* @param dataType data type of the expressions
* @param c1 name of the variable of expression 1's output
* @param c2 name of the variable of expression 2's output
*/
def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
*
* @param nullable used to decide whether we should add null check or not.
* @param isNull the code to check if the input is null.
* @param execute the code that should only be executed when the input is not null.
*/
def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
if (nullable) {
s"""
if (!$isNull) {
$execute
}
"""
} else {
"\n" + execute
}
}

/**
* List of java data types that have special accessors and setters in [[InternalRow]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
ev.isNull = "false"
val childrenHash = children.map { child =>
val childGen = child.gen(ctx)
childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
}.mkString("\n")
Expand All @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
"""
}

private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
if (nullable) {
s"""
if (!$isNull) {
$execution
}
"""
} else {
"\n" + execution
}
}

private def nullSafeElementHash(
input: String,
index: String,
Expand All @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
ctx: CodegenContext): String = {
val element = ctx.freshName("element")

generateNullCheck(nullable, s"$input.isNullAt($index)") {
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
s"""
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}
Expand Down