Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,54 +139,84 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
// Generate code that looks like:
//
// condA = ...
// if (condA) {
// valueA
// } else {
// def when_0 {
// condA = ...
// if (condA) {
// valueA
// return true
// }
// return false
// }
//
// def when_1 {
// condB = ...
// if (condB) {
// valueB
// } else {
// condC = ...
// if (condC) {
// valueC
// } else {
// elseValue
// }
// return true
// }
// return false
// }
//
// def when_2 {
// elseValue
// return true
// }
//
// if (when_0()) {} else { if (when_1()) {} else { if (when_2()) {} else { }}}

val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")

ctx.addMutableState("boolean", isNull, s"boolean ${isNull} = true;")
ctx.addMutableState(ctx.javaType(dataType), value,
s"${ctx.javaType(dataType)} ${value} = ${ctx.defaultValue(dataType)};")

def addCase(body: String) = {
val name = ctx.freshName("when_")
val code = s"""
public boolean ${name}(InternalRow ${ctx.INPUT_ROW}) {
${body}
}"""
ctx.addNewFunction(name, code)
name
}

val cases = branches.map { case (condExpr, valueExpr) =>
val cond = condExpr.gen(ctx)
val res = valueExpr.gen(ctx)
s"""
${cond.code}
if (!${cond.isNull} && ${cond.value}) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
${isNull} = ${res.isNull};
${value} = ${res.value};
return true;
}
return false;
"""
}

var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")

elseValue.foreach { elseExpr =>
val elseCase = elseValue.map { elseExpr =>
val res = elseExpr.gen(ctx)
generatedCode +=
s"""
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
"""
s"""
${res.code}
${isNull} = ${res.isNull};
${value} = ${res.value};
return true;
"""
}

generatedCode += "}\n" * cases.size
val names = (cases ++ elseCase).map(c => addCase(c))
val calls = names
.map(fn => s"if (${fn}(${ctx.INPUT_ROW})) { } else { ").mkString("", "", "}" * names.size)

s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode
"""
${isNull} = true;
${value} = ${ctx.defaultValue(dataType)};
$calls
boolean ${ev.isNull} = ${isNull};
${ctx.javaType(dataType)} ${ev.value} = ${value};
"""
}

override def toString: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}


test("SPARK-13242: split when clauses") {
val cases = 50
val clauses = 20

// Generate an individual case
def generateCase(n: Int): (Expression, Expression) = {
val condition = (1 to clauses)
.map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
.reduceLeft[Expression]((l, r) => Or(l, r))
(condition, Literal(n))
}

val expression = CaseWhen((1 to cases).map(generateCase(_)))
val plan = GenerateMutableProjection.generate(Seq(expression))()
val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
val actual = plan(input).toSeq(Seq(expression.dataType))

assert(actual(0) == cases)
}

test("test generated safe and unsafe projection") {
val schema = new StructType(Array(
StructField("a", StringType, true),
Expand Down