improve codegen
Davies Liu committed Feb 19, 2016
commit c92e45717cc2e23e5a9f552f99b7b1233d161f8b
@@ -62,8 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
ev.isNull = ctx.currentVars(ordinal).isNull
ev.value = ctx.currentVars(ordinal).value
val oev = ctx.currentVars(ordinal)
// assert(oev.code == "", s"$this has not been evaluated yet.")
ev.isNull = oev.isNull
ev.value = oev.value
""
} else if (nullable) {
s"""
@@ -187,8 +187,10 @@ case class Expand(

val numOutput = metricTerm(ctx, "numOutputRows")
val i = ctx.freshName("i")
// these columns have to be declared before the loop.
val evaluate = evaluateVariables(outputColumns)
s"""
|${outputColumns.map(_.code).mkString("\n").trim}
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
| switch ($i) {
| ${cases.mkString("\n").trim}
@@ -76,7 +76,10 @@ trait CodegenSupport extends SparkPlan {
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
doProduce(ctx)
s"""
|/*** PRODUCE: ${commentSafe(this.simpleString)} */
|${doProduce(ctx)}
""".stripMargin
}

/**
@@ -108,6 +111,38 @@ trait CodegenSupport extends SparkPlan {
parent.consumeChild(ctx, this, input, row)
}

/**
 * Returns source code to evaluate all the variables, and clear the code of them, to prevent
 * them to be evaluated twice.
 */

Contributor:
Can you add a high level comment that describes the overall framework? I think the important things to include are:
  • how it works in general?
  • how should an operator that does not short circuit (e.g. project/sort) use this?
  • how should an operator that does short circuit use this (if different)?

Contributor:
I was imagining something like: evaluateAttributes(Seq[Expression]), which evaluates all the attribute references in the tree that haven't been evaluated yet. This is kind of similar to what you have below.

Contributor Author:
Some variables could be generated in the middle of the plan, for example in aggregate and join, so we can't always use the references of the current plan to determine which expressions are used or not. So I have two different functions here; we could pass in the used references to the function below.

protected def evaluateVariables(variables: Seq[ExprCode]): String = {
Contributor:
Can you update the comment for ExprCode.code to specify what it means when it is empty.

Contributor Author:
done

val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
evaluate
}

/**
* Returns source code to evaluate the variables for required attributes, and clear the code
* of evaluated variables, to prevent them to be evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
var evaluateVars = ""
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
Member:
@davies I was just reviewing build warnings, and it flags this line. ev.code is a Block rather than String. Should it be ev.code.nonEmpty && ... instead?

evaluateVars += ev.code.trim + "\n"
ev.code = ""
}
}
evaluateVars
}

protected def commentSafe(s: String): String = {
s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
}
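
A minimal, self-contained sketch of the pattern discussed in the review thread above. It uses toy stand-ins (Attr, Col and EvaluateDemo are illustrative names, not the patch's classes): emitting a variable's code also clears it, so nothing is evaluated twice, and a short-circuiting consumer can force only the columns it actually references while leaving the rest deferred.

```scala
// Toy stand-ins for Spark's Attribute / ExprCode; illustrative only, not the patch's API.
case class Attr(name: String)
class Col(var code: String, val value: String)

// Emit every pending variable once, then clear its code so it cannot be emitted again.
def evaluateVariables(vars: Seq[Col]): String = {
  val out = vars.filter(_.code.nonEmpty).map(_.code.trim).mkString("\n")
  vars.foreach(_.code = "")
  out
}

// Emit only the variables backing `required` attributes, leaving the rest deferred.
def evaluateRequiredVariables(attrs: Seq[Attr], vars: Seq[Col], required: Set[Attr]): String = {
  val sb = new StringBuilder
  vars.zip(attrs).foreach { case (v, a) =>
    if (v.code.nonEmpty && required.contains(a)) {
      sb.append(v.code.trim).append('\n')
      v.code = ""
    }
  }
  sb.toString
}

object EvaluateDemo extends App {
  val attrs = Seq(Attr("a"), Attr("b"))
  val vars = Seq(new Col("int a = row.getInt(0);", "a"),
                 new Col("long b = row.getLong(1);", "b"))
  // A short-circuiting consumer forces only what its condition needs before bailing out...
  println(evaluateRequiredVariables(attrs, vars, Set(Attr("a"))))  // emits only a's code
  // ...and a later consumer then evaluates whatever is still pending, exactly once.
  println(evaluateVariables(vars))                                 // emits only b's code
}
```

In the patch itself this is driven from consumeChild() (see the CONSUME hunk below), which calls evaluateRequiredVariables(child.output, inputVars, references), so an operator such as Filter only forces the columns its condition references before its early continue.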

/**
 * Consume the columns generated from its child, call doConsume() or emit the rows.
 */

Contributor:
Can you comment what the semantics are if row == null vs not null?

@@ -117,19 +152,22 @@ trait CodegenSupport extends SparkPlan {
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = variablePrefix
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
val evals = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
val inputVars =
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
}
} else {
input
}
s"""
| ${evals.map(_.code).mkString("\n")}
| ${doConsume(ctx, evals)}
""".stripMargin
} else {
doConsume(ctx, input)
}
s"""
|
|/*** CONSUME: ${commentSafe(this.simpleString)} */
|${evaluateRequiredVariables(child.output, inputVars, references)}
|${doConsume(ctx, inputVars)}
""".stripMargin
}
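
On the row == null question raised in the review above, a hedged reading of consumeChild as a toy model (Var and consumeModel are illustrative names, not Spark's API): a non-null row means the child materialized an InternalRow, so column accessors are regenerated from it; a null row means the child handed over column variables directly, some of which may still be deferred.

```scala
// Toy model of the two consume conventions; illustrative only, not Spark's API.
final class Var(var code: String, val value: String)

def consumeModel(input: Seq[Var], row: String): String = {
  val inputVars =
    if (row != null) {
      // The child gave us a row variable: regenerate an accessor per output column.
      Seq(new Var(s"int c0 = $row.getInt(0);", "c0"),
          new Var(s"long c1 = $row.getLong(1);", "c1"))
    } else {
      // The child gave us column variables; some may not have been evaluated yet.
      input
    }
  // Force the pending variables (the patch limits this to the operator's `references`),
  // clear their code, then hand off to doConsume.
  val forced = inputVars.filter(_.code.nonEmpty).map { v => val c = v.code; v.code = ""; c }
  (forced :+ "// ... doConsume(inputVars) ...").mkString("\n")
}
```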

/**
@@ -183,13 +221,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
| while (input.hasNext()) {
| while (!shouldStop() && input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| }
| }
""".stripMargin
}
@@ -251,7 +285,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}

/** Codegened pipeline for:
* ${plan.treeString.trim}
* ${commentSafe(plan.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

@@ -305,7 +339,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
| currentRows.add($row.copy());
|currentRows.add($row.copy());
""".stripMargin
} else {
assert(input != null)
@@ -317,13 +351,14 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRows.add(${code.value}.copy());
|${evaluateVariables(input)}
|${code.code.trim}
|currentRows.add(${code.value}.copy());
""".stripMargin
} else {
// There is no columns
s"""
| currentRows.add(unsafeRow);
|currentRows.add(unsafeRow);
""".stripMargin
}
}
@@ -116,6 +116,14 @@ case class TungstenAggregate(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct

override def references: AttributeSet = {
AttributeSet(groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap {
case AggregateExpression(f, Final | PartialMerge, _) => f.inputAggBufferAttributes
case AggregateExpression(f, Partial | Complete, _) => f.references
})
child.outputSet
}

override def supportCodegen: Boolean = {
// ImperativeAggregate is not supported right now
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,47 +172,47 @@ case class TungstenAggregate(
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
val initBufVar = evaluateVariables(bufVars)

// generate variables for output
val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
val aggResults = functions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, bufferAttrs).gen(ctx)
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
}
(resultVars, s"""
| ${aggResults.map(_.code).mkString("\n")}
| ${resultVars.map(_.code).mkString("\n")}
| ${evaluateVariables(aggResults)}
| ${evaluateVariables(resultVars)}
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
(bufVars, "")
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.gen(ctx))
(resultVars, resultVars.map(_.code).mkString("\n"))
(resultVars, evaluateVariables(resultVars))
}

val doAgg = ctx.freshName("doAggregateWithoutKey")
ctx.addNewFunction(doAgg,
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
| $initBufVar
|
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
""".stripMargin)

val numOutput = metricTerm(ctx, "numOutputRows")
s"""
| if (!$initAgg) {
| while (!$initAgg) {
| $initAgg = true;
| $doAgg();
|
@@ -241,7 +249,7 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n").trim}
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
""".stripMargin
@@ -252,8 +260,7 @@ case class TungstenAggregate(
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
.filter(_.isInstanceOf[DeclarativeAggregate])
.map(_.asInstanceOf[DeclarativeAggregate])
private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
private val bufferSchema = StructType.fromAttributes(bufferAttributes)
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

// The name for HashMap
private var hashMapTerm: String = _
@@ -318,7 +325,7 @@ case class TungstenAggregate(
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()

@@ -381,13 +388,13 @@ case class TungstenAggregate(
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
// evaluate the aggregation result
ctx.currentVars = bufferVars
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, bufferAttributes).gen(ctx)
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
// generate the final result
ctx.currentVars = keyVars ++ aggResults
@@ -396,11 +403,9 @@ case class TungstenAggregate(
BindReferences.bindReference(e, inputAttrs).gen(ctx)
}
s"""
${keyVars.map(_.code).mkString("\n")}
${bufferVars.map(_.code).mkString("\n")}
${aggResults.map(_.code).mkString("\n")}
${resultVars.map(_.code).mkString("\n")}

${evaluateVariables(keyVars)}
${evaluateVariables(bufferVars)}
${evaluateVariables(aggResults)}
${consume(ctx, resultVars)}
"""

@@ -422,10 +427,7 @@ case class TungstenAggregate(
val eval = resultExpressions.map{ e =>
BindReferences.bindReference(e, groupingAttributes).gen(ctx)
}
s"""
${eval.map(_.code).mkString("\n")}
${consume(ctx, eval)}
"""
consume(ctx, eval)
}
}

@@ -508,8 +510,8 @@ case class TungstenAggregate(
ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)

val inputAttr = bufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
val inputAttr = aggregateBufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
// TODO: support subexpression elimination
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +559,7 @@ case class TungstenAggregate(
$incCounter

// evaluate aggregate function
${evals.map(_.code).mkString("\n").trim}
${evaluateVariables(evals)}
// update aggregate buffer
${updates.mkString("\n").trim}
"""
@@ -43,11 +43,12 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
val output = exprs.map(_.gen(ctx))
val resultVars = exprs.map(_.gen(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
s"""
| ${output.map(_.code).mkString("\n")}
|
| ${consume(ctx, output)}
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
|${consume(ctx, resultVars)}
""".stripMargin
}

@@ -89,11 +90,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
s""
}
s"""
| ${eval.code}
| if ($nullCheck ${eval.value}) {
| $numOutput.add(1);
| ${consume(ctx, ctx.currentVars)}
| }
|${eval.code}
|if (!($nullCheck ${eval.value})) continue;
|$numOutput.add(1);
|${consume(ctx, ctx.currentVars)}
""".stripMargin
}

@@ -224,15 +224,13 @@ case class Range(
| }
| }
|
| while (!$overflow && $checkEnd) {
| while (!$overflow && $checkEnd && !shouldStop()) {
| long $value = $number;
| $number += ${step}L;
| if ($number < $value ^ ${step}L < 0) {
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
|
| if (shouldStop()) return;
| }
""".stripMargin
}