Skip to content
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
ev.isNull = ctx.currentVars(ordinal).isNull
ev.value = ctx.currentVars(ordinal).value
""
val oev = ctx.currentVars(ordinal)
ev.isNull = oev.isNull
ev.value = oev.value
oev.code
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ private[sql] case class PhysicalRDD(
| while ($input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| $numOutputRows.add(1);
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ case class Expand(

val numOutput = metricTerm(ctx, "numOutputRows")
val i = ctx.freshName("i")
// these column have to declared before the loop.
val evaluate = evaluateVariables(outputColumns)
s"""
|${outputColumns.map(_.code).mkString("\n").trim}
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
| switch ($i) {
| ${cases.mkString("\n").trim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ trait CodegenSupport extends SparkPlan {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
doProduce(ctx)
s"""
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
|${doProduce(ctx)}
""".stripMargin
}

/**
Expand Down Expand Up @@ -115,6 +118,39 @@ trait CodegenSupport extends SparkPlan {
parent.consumeChild(ctx, this, input, row)
}

/**
* Returns source code to evaluate all the variables, and clear the code of them, to prevent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a high level comment that describes the overall framework? I think the important things to include are:

  • how it works in general?
  • how should an operator that does not short circuit (e.g. project/sort) use this?
  • how should an operator that does short circuit use this (if different)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was imagining something like:

evaluateAttributes(Seq[Expression]) which evaluates all the attribute refernces in the tree that haven't been. This is kind of similar to what you have below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some variables could be generated in the middle of the plan, for example, aggregate, and join, so we can't always use the references of current plan to determine which expression is used or not. So I have two different functions here, we could pass in the used references to the function below.

* them to be evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the comment for ExprCode.code to specify what it means when it is empty.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
evaluate
}

/**
* Returns source code to evaluate the variables for required attributes, and clear the code
* of evaluated variables, to prevent them to be evaluated twice..
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
var evaluateVars = ""
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davies I was just reviewing build warnings, and it flags this line. ev.code is a Block rather than String. Should it be ev.code.nonEmpty && ... instead?

evaluateVars += ev.code.trim + "\n"
ev.code = ""
}
}
evaluateVars
}

/**
* The subset of inputSet those should be evaluated before this plan.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good place to document how this whole thing works in a couple of sentences. Something describing that we defer attribute access in the generated function. We access all the attributes needed by the operator at the beginning if it was not already referenced earlier in the pipeline.

Might also update the commit message with this since this is what most of the patch is about.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

*/
def usedInputs: AttributeSet = references

/**
* Consume the columns generated from it's child, call doConsume() or emit the rows.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you comment what the semantics are if row == null vs not null?

*/
Expand All @@ -124,19 +160,22 @@ trait CodegenSupport extends SparkPlan {
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = variablePrefix
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
val evals = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
val inputVars =
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
}
} else {
input
}
s"""
| ${evals.map(_.code).mkString("\n")}
| ${doConsume(ctx, evals)}
""".stripMargin
} else {
doConsume(ctx, input)
}
s"""
|
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
|${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
|${doConsume(ctx, inputVars)}
""".stripMargin
}

/**
Expand Down Expand Up @@ -198,13 +237,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
| while ($input.hasNext()) {
| while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| }
| }
""".stripMargin
}
Expand Down Expand Up @@ -345,10 +380,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(input)
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
|$evaluateInputs
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ case class TungstenAggregate(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct

override def usedInputs: AttributeSet = inputSet

override def supportCodegen: Boolean = {
// ImperativeAggregate is not supported right now
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
Expand Down Expand Up @@ -164,47 +166,48 @@ case class TungstenAggregate(
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
val initBufVar = evaluateVariables(bufVars)

// generate variables for output
val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
val aggResults = functions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, bufferAttrs).gen(ctx)
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
}
(resultVars, s"""
| ${aggResults.map(_.code).mkString("\n")}
| ${resultVars.map(_.code).mkString("\n")}
|$evaluateAggResults
|${evaluateVariables(resultVars)}
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
(bufVars, "")
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.gen(ctx))
(resultVars, resultVars.map(_.code).mkString("\n"))
(resultVars, evaluateVariables(resultVars))
}

val doAgg = ctx.freshName("doAggregateWithoutKey")
ctx.addNewFunction(doAgg,
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
| $initBufVar
|
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
""".stripMargin)

val numOutput = metricTerm(ctx, "numOutputRows")
s"""
| if (!$initAgg) {
| while (!$initAgg) {
| $initAgg = true;
| $doAgg();
|
Expand Down Expand Up @@ -241,7 +244,7 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n").trim}
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
""".stripMargin
Expand All @@ -252,8 +255,7 @@ case class TungstenAggregate(
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
.filter(_.isInstanceOf[DeclarativeAggregate])
.map(_.asInstanceOf[DeclarativeAggregate])
private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
private val bufferSchema = StructType.fromAttributes(bufferAttributes)
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

// The name for HashMap
private var hashMapTerm: String = _
Expand Down Expand Up @@ -318,7 +320,7 @@ case class TungstenAggregate(
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()

Expand Down Expand Up @@ -380,27 +382,28 @@ case class TungstenAggregate(
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
val evaluateKeyVars = evaluateVariables(keyVars)
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, bufferAttributes).gen(ctx)
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, inputAttrs).gen(ctx)
}
s"""
${keyVars.map(_.code).mkString("\n")}
${bufferVars.map(_.code).mkString("\n")}
${aggResults.map(_.code).mkString("\n")}
${resultVars.map(_.code).mkString("\n")}

$evaluateKeyVars
$evaluateBufferVars
$evaluateAggResults
${consume(ctx, resultVars)}
"""

Expand All @@ -422,10 +425,7 @@ case class TungstenAggregate(
val eval = resultExpressions.map{ e =>
BindReferences.bindReference(e, groupingAttributes).gen(ctx)
}
s"""
${eval.map(_.code).mkString("\n")}
${consume(ctx, eval)}
"""
consume(ctx, eval)
}
}

Expand Down Expand Up @@ -508,8 +508,8 @@ case class TungstenAggregate(
ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)

val inputAttr = bufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
val inputAttr = aggregateBufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
// TODO: support subexpression elimination
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
Expand Down Expand Up @@ -557,7 +557,7 @@ case class TungstenAggregate(
$incCounter

// evaluate aggregate function
${evals.map(_.code).mkString("\n").trim}
${evaluateVariables(evals)}
// update aggregate buffer
${updates.mkString("\n").trim}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def usedInputs: AttributeSet = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I think we need to have stronger semantics so that this is not necessary. I think if each operator just always ensured the referenced attributes were populated at the start of consume, we don't need this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, this is need to handle the cases that an Attribute could appear twice in projectList, but we can't generate the code for an Attribute twice.

// only the attributes those are used at least twice should be evaluated before this plan,
// otherwise we could defer the evaluation until output attribute is actually used.
val usedExprIds = projectList.flatMap(_.collect {
case a: Attribute => a.exprId
})
val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
val output = exprs.map(_.gen(ctx))
val resultVars = exprs.map(_.gen(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
s"""
| ${output.map(_.code).mkString("\n")}
|
| ${consume(ctx, output)}
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
|${consume(ctx, resultVars)}
""".stripMargin
}

Expand Down Expand Up @@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
s""
}
s"""
| ${eval.code}
| if ($nullCheck ${eval.value}) {
| $numOutput.add(1);
| ${consume(ctx, ctx.currentVars)}
| }
|${eval.code}
|if (!($nullCheck ${eval.value})) continue;
|$numOutput.add(1);
|${consume(ctx, ctx.currentVars)}
""".stripMargin
}

Expand Down Expand Up @@ -228,15 +238,13 @@ case class Range(
| }
| }
|
| while (!$overflow && $checkEnd) {
| while (!$overflow && $checkEnd && !shouldStop()) {
| long $value = $number;
| $number += ${step}L;
| if ($number < $value ^ ${step}L < 0) {
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
|
| if (shouldStop()) return;
| }
""".stripMargin
}
Expand Down
Loading