From 5b1097f925dca47a3998edbea3f60bd6913400eb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 12:01:49 -0800 Subject: [PATCH 1/2] improve nullSafeCodeGen for unary, binary and ternary expression --- .../sql/catalyst/expressions/Expression.scala | 118 +++++++++--------- .../expressions/codegen/CodeGenerator.scala | 31 ++++- .../spark/sql/catalyst/expressions/misc.scala | 16 +-- 3 files changed, 87 insertions(+), 78 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db17ba7c84ff..9c7ce844182d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -350,21 +350,23 @@ abstract class UnaryExpression extends Expression { ev: ExprCode, f: String => String): String = { val eval = child.gen(ctx) - if (nullable) { - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${eval.isNull}) { - ${f(eval.value)} - } - """ + + val declareIsNull = if (nullable) { + // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if + // child is not nullable, so we make `ev.isNull` a variable here. + s"boolean ${ev.isNull} = ${eval.isNull};" } else { + // If this expression is not nullable, which means `ev.isNull` will always be false, thus we + // don't need to declare the isNull variable. ev.isNull = "false" - eval.code + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${f(eval.value)} - """ + "" } + + s""" + ${eval.code} + $declareIsNull + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """ + ctx.genNullCheck(child.nullable, eval.isNull)(f(eval.value)) } override def sql: String = s"($prettyName(${child.sql}))" @@ -442,30 +444,30 @@ abstract class BinaryExpression extends Expression { f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) - if (nullable) { - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + val declareIsNull = if (nullable) { + // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if + // child is not nullable, so we make `ev.isNull` a variable here. + s"boolean ${ev.isNull} = true;" } else { + // If this expression is not nullable, which means `ev.isNull` will always be false, thus we + // don't need to declare the `ev.isNull` variable. ev.isNull = "false" - s""" - ${eval1.code} - ${eval2.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + "" + } + + s""" + $declareIsNull + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """ + eval1.code + ctx.genNullCheck(left.nullable, eval1.isNull) { + eval2.code + ctx.genNullCheck(right.nullable, eval2.isNull) { + val setNotNull = if (nullable) s"${ev.isNull} = false;" else "" + val resultCode = f(eval1.value, eval2.value) + s""" + $setNotNull // resultCode could change nullability, so this should sit before it. + $resultCode + """ + } } } @@ -581,32 +583,32 @@ abstract class TernaryExpression extends Expression { ev: ExprCode, f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - if (nullable) { - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode - } - } - } - """ + + val declareIsNull = if (nullable) { + // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if + // child is not nullable, so we make `ev.isNull` a variable here. + s"boolean ${ev.isNull} = true;" } else { + // If this expression is not nullable, which means `ev.isNull` will always be false, thus we + // don't need to declare the `ev.isNull` variable. ev.isNull = "false" - s""" - ${evals(0).code} - ${evals(1).code} - ${evals(2).code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + "" + } + + s""" + $declareIsNull + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + """ + evals(0).code + ctx.genNullCheck(children(0).nullable, evals(0).isNull) { + evals(1).code + ctx.genNullCheck(children(1).nullable, evals(1).isNull) { + evals(2).code + ctx.genNullCheck(children(2).nullable, evals(2).isNull) { + val setNotNull = if (nullable) s"${ev.isNull} = false;" else "" + val resultCode = f(evals(0).value, evals(1).value, evals(2).value) + s""" + $setNotNull // resultCode could change nullability, so this should sit before it. + $resultCode + """ + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f..22192258a4ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -355,17 +355,36 @@ class CodegenContext { } /** - * Generates code for greater of two expressions. - * - * @param dataType data type of the expressions - * @param c1 name of the variable of expression 1's output - * @param c2 name of the variable of expression 2's output - */ + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for adding null check if necessary. + * + * @param nullable we can avoid null check if it's false. + * @param isNull the code to check null. + * @param execution the code to run execution. + */ + def genNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8480c3f9a12f..5e355b232c84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.genNullCheck(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression """ } - private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { - if (nullable) { - s""" - if (!$isNull) { - $execution - } - """ - } else { - "\n" + execution - } - } - private def nullSafeElementHash( input: String, index: String, @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - generateNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.genNullCheck(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} From 566006e34a5d0acfae81931ac1627a68ace6543d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 Jan 2016 16:01:44 -0800 Subject: [PATCH 2/2] improve readability --- .../sql/catalyst/expressions/Expression.scala | 144 +++++++++--------- .../expressions/codegen/CodeGenerator.scala | 15 +- .../spark/sql/catalyst/expressions/misc.scala | 4 +- 3 files changed, 86 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9c7ce844182d..353fb92581d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. @@ -349,24 +349,25 @@ abstract class UnaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) - - val declareIsNull = if (nullable) { - // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if - // child is not nullable, so we make `ev.isNull` a variable here. - s"boolean ${ev.isNull} = ${eval.isNull};" + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ } else { - // If this expression is not nullable, which means `ev.isNull` will always be false, thus we - // don't need to declare the isNull variable. ev.isNull = "false" - "" + s""" + ${childGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ } - - s""" - ${eval.code} - $declareIsNull - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ctx.genNullCheck(child.nullable, eval.isNull)(f(eval.value)) } override def sql: String = s"($prettyName(${child.sql}))" @@ -442,32 +443,34 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } - val declareIsNull = if (nullable) { - // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if - // child is not nullable, so we make `ev.isNull` a variable here. - s"boolean ${ev.isNull} = true;" + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ } else { - // If this expression is not nullable, which means `ev.isNull` will always be false, thus we - // don't need to declare the `ev.isNull` variable. ev.isNull = "false" - "" - } - - s""" - $declareIsNull - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + eval1.code + ctx.genNullCheck(left.nullable, eval1.isNull) { - eval2.code + ctx.genNullCheck(right.nullable, eval2.isNull) { - val setNotNull = if (nullable) s"${ev.isNull} = false;" else "" - val resultCode = f(eval1.value, eval2.value) - s""" - $setNotNull // resultCode could change nullability, so this should sit before it. - $resultCode - """ - } + s""" + ${leftGen.code} + ${rightGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ } } @@ -529,7 +532,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -555,11 +558,11 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( ctx: CodegenContext, @@ -571,44 +574,49 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + } - val declareIsNull = if (nullable) { - // If this expression is nullable, which means the `f` may need to change `ev.isNull` even if - // child is not nullable, so we make `ev.isNull` a variable here. - s"boolean ${ev.isNull} = true;" + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ } else { - // If this expression is not nullable, which means `ev.isNull` will always be false, thus we - // don't need to declare the `ev.isNull` variable. ev.isNull = "false" - "" - } - - s""" - $declareIsNull - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + evals(0).code + ctx.genNullCheck(children(0).nullable, evals(0).isNull) { - evals(1).code + ctx.genNullCheck(children(1).nullable, evals(1).isNull) { - evals(2).code + ctx.genNullCheck(children(2).nullable, evals(2).isNull) { - val setNotNull = if (nullable) s"${ev.isNull} = false;" else "" - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - $setNotNull // resultCode could change nullability, so this should sit before it. - $resultCode - """ - } - } + s""" + ${leftGen.code} + ${midGen.code} + ${rightGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 22192258a4ba..9fb521291919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -367,21 +367,22 @@ class CodegenContext { } /** - * Generates code for adding null check if necessary. + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. * - * @param nullable we can avoid null check if it's false. - * @param isNull the code to check null. - * @param execution the code to run execution. + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. */ - def genNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { if (nullable) { s""" if (!$isNull) { - $execution + $execute } """ } else { - "\n" + execution + "\n" + execute } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5e355b232c84..36e1fa1176d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + ctx.genNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -347,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - ctx.genNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)}