Skip to content

Commit aace0a3

Browse files
author
Davies Liu
committed
fallback in case-when
1 parent e430614 commit aace0a3

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
8686
* @param elseValue optional value for the else branch
8787
*/
8888
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
89-
extends Expression {
89+
extends Expression with CodegenFallback {
9090

9191
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
9292

@@ -136,7 +136,16 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
136136
}
137137
}
138138

139+
def shouldCodegen: Boolean = {
140+
branches.length < CaseWhen.MAX_NUMBER_OF_SWITCHES
141+
}
142+
139143
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
144+
if (!shouldCodegen) {
145+
// Fall back to interpreted mode if there are too many branches; otherwise the generated code may reach the
146+
// 64KB limit (the maximum bytecode size of a single Java method).
147+
return super.genCode(ctx, ev)
148+
}
140149
// Generate code that looks like:
141150
//
142151
// condA = ...
@@ -205,6 +214,9 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
205214
/** Factory methods for CaseWhen. */
206215
object CaseWhen {
207216

217+
// The maximum number of switches supported with codegen.
218+
val MAX_NUMBER_OF_SWITCHES = 20
219+
208220
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
209221
CaseWhen(branches, Option(elseValue))
210222
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
5858
}
5959
}
6060

61+
test("SPARK-13242: complicated case-when expressions") {
62+
val cases = 50
63+
val clauses = 20
64+
65+
// Generate an individual case
66+
def generateCase(n: Int): (Expression, Expression) = {
67+
val condition = (1 to clauses)
68+
.map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
69+
.reduceLeft[Expression]((l, r) => Or(l, r))
70+
(condition, Literal(n))
71+
}
72+
73+
val expression = CaseWhen((1 to cases).map(generateCase(_)))
74+
75+
val plan = GenerateMutableProjection.generate(Seq(expression))()
76+
val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
77+
val actual = plan(input).toSeq(Seq(expression.dataType))
78+
79+
assert(actual(0) == cases)
80+
}
81+
6182
test("test generated safe and unsafe projection") {
6283
val schema = new StructType(Array(
6384
StructField("a", StringType, true),

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
416416

417417
private def supportCodegen(e: Expression): Boolean = e match {
418418
case e: LeafExpression => true
419+
case e: CaseWhen => e.shouldCodegen
419420
// CodegenFallback requires the input to be an InternalRow
420421
case e: CodegenFallback => false
421422
case _ => true

0 commit comments

Comments
 (0)