Skip to content

Commit 6400eb2

Browse files
committed
Simplify the solution.
1 parent 7428fd4 commit 6400eb2

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -174,28 +174,29 @@ trait CodegenSupport extends SparkPlan {
174174
input: Seq[ExprCode],
175175
row: String = null): String = {
176176
ctx.freshNamePrefix = variablePrefix
177+
val realUsedInput =
178+
if (row != null && consumeUnsafeRow) {
179+
// If this SparkPlan consumes UnsafeRow and there is an UnsafeRow passed in,
180+
// we don't need to evaluate inputs because doConsume will directly consume the UnsafeRow.
181+
AttributeSet.empty
182+
} else {
183+
usedInputs
184+
}
185+
177186
val inputVars =
178187
if (row != null) {
179-
if (!consumeUnsafeRow) {
180-
// If this SparkPlan can't consume UnsafeRow and there is an UnsafeRow,
181-
// we extract the columns from the row and call doConsume.
182-
ctx.currentVars = null
183-
ctx.INPUT_ROW = row
184-
child.output.zipWithIndex.map { case (attr, i) =>
185-
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
186-
}
187-
} else {
188-
// If this SparkPlan consumes UnsafeRow and there is an UnsafeRow,
189-
// we don't need to unpack variables from the row.
190-
Seq.empty
188+
ctx.currentVars = null
189+
ctx.INPUT_ROW = row
190+
child.output.zipWithIndex.map { case (attr, i) =>
191+
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
191192
}
192193
} else {
193194
input
194195
}
195196
s"""
196197
|
197198
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
198-
|${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
199+
|${evaluateRequiredVariables(child.output, inputVars, realUsedInput)}
199200
|${doConsume(ctx, inputVars, row)}
200201
""".stripMargin
201202
}
@@ -245,19 +246,12 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
245246
val input = ctx.freshName("input")
246247
// Right now, InputAdapter is only used when there is one upstream.
247248
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
248-
val row = ctx.freshName("row")
249249

250-
// If the parent of this InputAdapter can't consume UnsafeRow,
251-
// we unpack variables from the row.
252-
val columns: Seq[ExprCode] = if (!this.parent.consumeUnsafeRow) {
253-
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
254-
ctx.INPUT_ROW = row
255-
ctx.currentVars = null
256-
exprs.map(_.gen(ctx))
257-
} else {
258-
// If the parent consumes UnsafeRow, we don't need to unpack the row.
259-
Seq.empty
260-
}
250+
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
251+
val row = ctx.freshName("row")
252+
ctx.INPUT_ROW = row
253+
ctx.currentVars = null
254+
val columns = exprs.map(_.gen(ctx))
261255
s"""
262256
| while (!shouldStop() && $input.hasNext()) {
263257
| InternalRow $row = (InternalRow) $input.next();

0 commit comments

Comments
 (0)