Skip to content

Commit c1da4d4

Browse files
cloud-fandavies
authored andcommitted
[SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression
The current implementation is sub-optimal: * If an expression is always nullable, e.g. `Unhex`, we can still remove null check for children if they are not nullable. * If an expression has some non-nullable children, we can still remove null check for these children and keep null check for others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan <[email protected]> Closes #10987 from cloud-fan/null-check.
1 parent 5a8b978 commit c1da4d4

File tree

3 files changed

+85
-67
lines changed

3 files changed

+85
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression {
320320

321321
/**
322322
* Called by unary expressions to generate a code block that returns null if its parent returns
323-
* null, and if not not null, use `f` to generate the expression.
323+
* null, and if not null, use `f` to generate the expression.
324324
*
325325
* As an example, the following does a boolean inversion (i.e. NOT).
326326
* {{{
@@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression {
340340

341341
/**
342342
* Called by unary expressions to generate a code block that returns null if its parent returns
343-
* null, and if not not null, use `f` to generate the expression.
343+
* null, and if not null, use `f` to generate the expression.
344344
*
345345
* @param f function that accepts the non-null evaluation result name of child and returns Java
346346
* code to compute the output.
@@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression {
349349
ctx: CodegenContext,
350350
ev: ExprCode,
351351
f: String => String): String = {
352-
val eval = child.gen(ctx)
352+
val childGen = child.gen(ctx)
353+
val resultCode = f(childGen.value)
354+
353355
if (nullable) {
354-
eval.code + s"""
355-
boolean ${ev.isNull} = ${eval.isNull};
356+
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
357+
s"""
358+
${childGen.code}
359+
boolean ${ev.isNull} = ${childGen.isNull};
356360
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
357-
if (!${eval.isNull}) {
358-
${f(eval.value)}
359-
}
361+
$nullSafeEval
360362
"""
361363
} else {
362364
ev.isNull = "false"
363-
eval.code + s"""
365+
s"""
366+
${childGen.code}
364367
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
365-
${f(eval.value)}
368+
$resultCode
366369
"""
367370
}
368371
}
@@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression {
440443
ctx: CodegenContext,
441444
ev: ExprCode,
442445
f: (String, String) => String): String = {
443-
val eval1 = left.gen(ctx)
444-
val eval2 = right.gen(ctx)
445-
val resultCode = f(eval1.value, eval2.value)
446+
val leftGen = left.gen(ctx)
447+
val rightGen = right.gen(ctx)
448+
val resultCode = f(leftGen.value, rightGen.value)
449+
446450
if (nullable) {
451+
val nullSafeEval =
452+
leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
453+
rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
454+
s"""
455+
${ev.isNull} = false; // resultCode could change nullability.
456+
$resultCode
457+
"""
458+
}
459+
}
460+
447461
s"""
448-
${eval1.code}
449-
boolean ${ev.isNull} = ${eval1.isNull};
462+
boolean ${ev.isNull} = true;
450463
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
451-
if (!${ev.isNull}) {
452-
${eval2.code}
453-
if (!${eval2.isNull}) {
454-
$resultCode
455-
} else {
456-
${ev.isNull} = true;
457-
}
458-
}
464+
$nullSafeEval
459465
"""
460-
461466
} else {
462467
ev.isNull = "false"
463468
s"""
464-
${eval1.code}
465-
${eval2.code}
469+
${leftGen.code}
470+
${rightGen.code}
466471
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
467472
$resultCode
468473
"""
@@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression {
527532

528533
/**
529534
* Default behavior of evaluation according to the default nullability of TernaryExpression.
530-
* If subclass of BinaryExpression override nullable, probably should also override this.
535+
* If subclass of TernaryExpression override nullable, probably should also override this.
531536
*/
532537
override def eval(input: InternalRow): Any = {
533538
val exprs = children
@@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression {
553558
sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
554559

555560
/**
556-
* Short hand for generating binary evaluation code.
561+
* Short hand for generating ternary evaluation code.
557562
* If either of the sub-expressions is null, the result of this computation
558563
* is assumed to be null.
559564
*
560-
* @param f accepts two variable names and returns Java code to compute the output.
565+
* @param f accepts three variable names and returns Java code to compute the output.
561566
*/
562567
protected def defineCodeGen(
563568
ctx: CodegenContext,
@@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression {
569574
}
570575

571576
/**
572-
* Short hand for generating binary evaluation code.
577+
* Short hand for generating ternary evaluation code.
573578
* If either of the sub-expressions is null, the result of this computation
574579
* is assumed to be null.
575580
*
576-
* @param f function that accepts the 2 non-null evaluation result names of children
581+
* @param f function that accepts the 3 non-null evaluation result names of children
577582
* and returns Java code to compute the output.
578583
*/
579584
protected def nullSafeCodeGen(
580585
ctx: CodegenContext,
581586
ev: ExprCode,
582587
f: (String, String, String) => String): String = {
583-
val evals = children.map(_.gen(ctx))
584-
val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
588+
val leftGen = children(0).gen(ctx)
589+
val midGen = children(1).gen(ctx)
590+
val rightGen = children(2).gen(ctx)
591+
val resultCode = f(leftGen.value, midGen.value, rightGen.value)
592+
585593
if (nullable) {
594+
val nullSafeEval =
595+
leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
596+
midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
597+
rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
598+
s"""
599+
${ev.isNull} = false; // resultCode could change nullability.
600+
$resultCode
601+
"""
602+
}
603+
}
604+
}
605+
586606
s"""
587-
${evals(0).code}
588607
boolean ${ev.isNull} = true;
589608
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
590-
if (!${evals(0).isNull}) {
591-
${evals(1).code}
592-
if (!${evals(1).isNull}) {
593-
${evals(2).code}
594-
if (!${evals(2).isNull}) {
595-
${ev.isNull} = false; // resultCode could change nullability
596-
$resultCode
597-
}
598-
}
599-
}
609+
$nullSafeEval
600610
"""
601611
} else {
602612
ev.isNull = "false"
603613
s"""
604-
${evals(0).code}
605-
${evals(1).code}
606-
${evals(2).code}
614+
${leftGen.code}
615+
${midGen.code}
616+
${rightGen.code}
607617
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
608618
$resultCode
609619
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,37 @@ class CodegenContext {
402402
}
403403

404404
/**
405-
* Generates code for greater of two expressions.
406-
*
407-
* @param dataType data type of the expressions
408-
* @param c1 name of the variable of expression 1's output
409-
* @param c2 name of the variable of expression 2's output
410-
*/
405+
* Generates code for greater of two expressions.
406+
*
407+
* @param dataType data type of the expressions
408+
* @param c1 name of the variable of expression 1's output
409+
* @param c2 name of the variable of expression 2's output
410+
*/
411411
def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
412412
case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
413413
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
414414
}
415415

416+
/**
417+
* Generates code to do null safe execution, i.e. only execute the code when the input is not
418+
* null by adding null check if necessary.
419+
*
420+
* @param nullable used to decide whether we should add null check or not.
421+
* @param isNull the code to check if the input is null.
422+
* @param execute the code that should only be executed when the input is not null.
423+
*/
424+
def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
425+
if (nullable) {
426+
s"""
427+
if (!$isNull) {
428+
$execute
429+
}
430+
"""
431+
} else {
432+
"\n" + execute
433+
}
434+
}
435+
416436
/**
417437
* List of java data types that have special accessors and setters in [[InternalRow]].
418438
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
327327
ev.isNull = "false"
328328
val childrenHash = children.map { child =>
329329
val childGen = child.gen(ctx)
330-
childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
330+
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
331331
computeHash(childGen.value, child.dataType, ev.value, ctx)
332332
}
333333
}.mkString("\n")
@@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
338338
"""
339339
}
340340

341-
private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
342-
if (nullable) {
343-
s"""
344-
if (!$isNull) {
345-
$execution
346-
}
347-
"""
348-
} else {
349-
"\n" + execution
350-
}
351-
}
352-
353341
private def nullSafeElementHash(
354342
input: String,
355343
index: String,
@@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
359347
ctx: CodegenContext): String = {
360348
val element = ctx.freshName("element")
361349

362-
generateNullCheck(nullable, s"$input.isNullAt($index)") {
350+
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
363351
s"""
364352
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
365353
${computeHash(element, elementType, result, ctx)}

0 commit comments

Comments
 (0)