diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1eff2c4dd008..077281d383c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -139,22 +139,48 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Generate code that looks like: // - // condA = ... - // if (condA) { - // valueA - // } else { + // def when_0 { + // condA = ... + // if (condA) { + // valueA + // return true + // } + // return false + // } + // + // def when_1 { // condB = ... // if (condB) { // valueB - // } else { - // condC = ... - // if (condC) { - // valueC - // } else { - // elseValue - // } + // return true // } + // return false // } + // + // def when_2 { + // elseValue + // return true + // } + // + // if (when_0()) {} else { if (when_1()) {} else { if (when_2()) {} else { }}} + + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + + ctx.addMutableState("boolean", isNull, s"boolean ${isNull} = true;") + ctx.addMutableState(ctx.javaType(dataType), value, + s"${ctx.javaType(dataType)} ${value} = ${ctx.defaultValue(dataType)};") + + def addCase(body: String) = { + val name = ctx.freshName("when_") + val code = s""" + public boolean ${name}(InternalRow ${ctx.INPUT_ROW}) { + ${body} + }""" + ctx.addNewFunction(name, code) + name + } + val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.gen(ctx) val res = valueExpr.gen(ctx) @@ -162,31 +188,35 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E ${cond.code} if (!${cond.isNull} && ${cond.value}) { ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; + ${isNull} = ${res.isNull}; + ${value} = ${res.value}; + return true; } + return false; """ } - var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") - - elseValue.foreach { elseExpr => + val elseCase = elseValue.map { elseExpr => val res = elseExpr.gen(ctx) - generatedCode += - s""" - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - """ + s""" + ${res.code} + ${isNull} = ${res.isNull}; + ${value} = ${res.value}; + return true; + """ } - generatedCode += "}\n" * cases.size + val names = (cases ++ elseCase).map(c => addCase(c)) + val calls = names + .map(fn => s"if (${fn}(${ctx.INPUT_ROW})) { } else { ").mkString("", "", "}" * names.size) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode - """ + ${isNull} = true; + ${value} = ${ctx.defaultValue(dataType)}; + $calls + boolean ${ev.isNull} = ${isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${value}; + """ } override def toString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b5413fbe2bbc..396489ccb308 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + + test("SPARK-13242: split when clauses") { + val cases = 50 + val clauses = 20 + + // Generate an individual case + def generateCase(n: Int): (Expression, Expression) = { + val condition = (1 to clauses) + .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n"))) + .reduceLeft[Expression]((l, r) => Or(l, r)) + (condition, Literal(n)) + } + + val expression = CaseWhen((1 to cases).map(generateCase(_))) + val plan = GenerateMutableProjection.generate(Seq(expression))() + val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val actual = plan(input).toSeq(Seq(expression.dataType)) + + assert(actual(0) == cases) + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true),