diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 40b1eec63e55..c4e5b844299a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -110,6 +110,14 @@ trait CaseWhenLike extends Expression { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } + + /** + * Whether should it fallback to interpret mode or not. + * @return + */ + protected def shouldFallback: Boolean = { + branches.length > 20 + } } // scalastyle:off @@ -119,7 +127,7 @@ trait CaseWhenLike extends Expression { * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions */ // scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike with CodegenFallback { // Use private[this] Array to speed up evaluation. @transient private[this] lazy val branchesArr = branches.toArray @@ -157,6 +165,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (shouldFallback) { + // Fallback to interpreted mode if there are too many branches, as it may reach the + // 64K limit (limit on bytecode size for a single function). + return super[CodegenFallback].genCode(ctx, ev) + } val len = branchesArr.length val got = ctx.freshName("got") @@ -213,7 +226,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions */ // scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) + extends CaseWhenLike with CodegenFallback { // Use private[this] Array to speed up evaluation. @transient private[this] lazy val branchesArr = branches.toArray @@ -257,6 +271,11 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (shouldFallback) { + // Fallback to interpreted mode if there are too many branches, as it may reach the + // 64K limit (limit on bytecode size for a single function). + return super[CodegenFallback].genCode(ctx, ev) + } val keyEval = key.gen(ctx) val len = branchesArr.length val got = ctx.freshName("got") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cd2ef7dcd0cd..66a7e228f84b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -59,6 +59,28 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-13242: case-when expression with large number of branches (or cases)") { + val cases = 50 + val clauses = 20 + + // Generate an individual case + def generateCase(n: Int): Seq[Expression] = { + val condition = (1 to clauses) + .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n"))) + .reduceLeft[Expression]((l, r) => Or(l, r)) + Seq(condition, Literal(n)) + } + + val expression = CaseWhen((1 to cases).flatMap(generateCase(_))) + + val plan = GenerateMutableProjection.generate(Seq(expression))() + val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val actual = plan(input).toSeq(Seq(expression.dataType)) + + assert(actual(0) == cases) + } + + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true),