Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
this(toBoundExprs(expressions, inputSchema))

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
toBoundExprs(in, inputSchema)

def generate(
expressions: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
in.map(BindReferences.bindReference(_, inputSchema))
toBoundExprs(in, inputSchema)

/**
* Creates a code gen ordering for sorting this schema, in ascending order.
Expand Down Expand Up @@ -188,7 +188,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
extends Ordering[InternalRow] with KryoSerializable {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
this(toBoundExprs(ordering, inputSchema))

@transient
private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
toBoundExprs(in, inputSchema)

private def createCodeForStruct(
ctx: CodegenContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
toBoundExprs(in, inputSchema)

def generate(
expressions: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
this(toBoundExprs(ordering, inputSchema))

def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ package object expressions {
/**
* A helper function to bind given expressions to an input schema.
*/
def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
exprs.map(BindReferences.bindReference(_, inputSchema))
def toBoundExprs[A <: Expression](
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to minimize the chance that future changes suffer from the same issue. To do that, we should provide the API in a logical place; it does not make much sense to me that I need to look in package.scala to find a more performant version of BindReferences.bindReference(..) for a Seq. Can we move this function to the BindReferences object and name it bindReferences?

exprs: Seq[A],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent

inputSchema: Seq[Attribute]): Seq[A] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just change this to AttributeSeq?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes this function is called with a zero-length exprs (df.count, for example). I am attempting to avoid constructing the AttributeSeq in that case, because AttributeSeq's constructor eagerly builds a data structure based on the attributes (private val qualified3Part).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive is it to build those structures on empty data? We could consider caching an empty AttributeSeq and using that when there are no attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that just because exprs is empty, that would not necessarily mean inputSchema is empty. But experiments seem to indicate that inputSchema is also empty. So building the AttributeSeq would be extremely low-cost. I guess there is no reason to lazily build it.

lazy val inputSchemaAttrSeq: AttributeSeq = inputSchema
exprs.map(BindReferences.bindReference(_, inputSchemaAttrSeq))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ case class HashAggregateExec(
}
}
ctx.currentVars = bufVars ++ input
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
val boundUpdateExpr = toBoundExprs(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
Expand Down Expand Up @@ -825,7 +825,7 @@ case class HashAggregateExec(

val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
Expand All @@ -849,7 +849,7 @@ case class HashAggregateExec(
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
.map(BindReferences.bindReference(_, right.output))
val lkeys = toBoundExprs(HashJoin.rewriteKeyExpr(leftKeys), left.output)
val rkeys = toBoundExprs(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ case class SortMergeJoinExec(
input: Seq[Attribute]): Seq[ExprCode] = {
ctx.INPUT_ROW = row
ctx.currentVars = null
keys.map(BindReferences.bindReference(_, input).genCode(ctx))
val inputAttributeSeq: AttributeSeq = input
keys.map(BindReferences.bindReference(_, inputAttributeSeq).genCode(ctx))
Copy link
Contributor

@mgaido91 mgaido91 Jan 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not toBoundExprs(keys, input).map(_.genCode(ctx))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be...

}

private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
Expand Down