Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ case object UnresolvedSeed extends LeafExpression with Unevaluable {
*/
case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends UnaryExpression
with Unevaluable {
override lazy val preCanonicalized = child.preCanonicalized
override lazy val canonicalized = child.canonicalized
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ case class Cast(
override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)

override lazy val preCanonicalized: Expression = {
val basic = withNewChildren(Seq(child.preCanonicalized)).asInstanceOf[Cast]
override lazy val canonicalized: Expression = {
val basic = withNewChildren(Seq(child.canonicalized)).asInstanceOf[Cast]
if (timeZoneId.isDefined && !needsTimeZone) {
basic.withTimeZone(null)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ case class DynamicPruningSubquery(

override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"

override lazy val preCanonicalized: DynamicPruning = {
override lazy val canonicalized: DynamicPruning = {
copy(
pruningKey = pruningKey.preCanonicalized,
pruningKey = pruningKey.canonicalized,
buildQuery = buildQuery.canonicalized,
buildKeys = buildKeys.map(_.preCanonicalized),
buildKeys = buildKeys.map(_.canonicalized),
exprId = ExprId(0))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,43 +224,28 @@ abstract class Expression extends TreeNode[Expression] {
*/
def childrenResolved: Boolean = children.forall(_.resolved)

// Expression canonicalization is done in 2 phases:
// 1. Recursively canonicalize each node in the expression tree. This does not change the tree
// structure and is more like "node-local" canonicalization.
// 2. Find adjacent commutative operators in the expression tree, reorder them to get a
// static order and remove cosmetic variations. This may change the tree structure
// dramatically and is more like a "global" canonicalization.
//
// The first phase is done by `preCanonicalized`. It's a `lazy val` which recursively calls
// `preCanonicalized` on the children. This means that almost every node in the expression tree
// will instantiate the `preCanonicalized` variable, which is good for performance as you can
// reuse the canonicalization result of the children when you construct a new expression node.
//
// The second phase is done by `canonicalized`, which simply calls `Canonicalize` and is kind of
// the actual "user-facing API" of expression canonicalization. Only the root node of the
// expression tree will instantiate the `canonicalized` variable. This is different from
// `preCanonicalized`, because `canonicalized` does "global" canonicalization and most of the time
// you cannot reuse the canonicalization result of the children.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's good that we now have a better and simplified version, but we should still have a detailed comment to explain the new workflow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a process comment in 6c68d8b, let me know if it needs more details.

Copy link
Contributor Author

@peter-toth peter-toth Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've also updated the PR description.


/**
* An internal lazy val to implement expression canonicalization. It should only be called in
* `canonicalized`, or in subclass's `preCanonicalized` when the subclass overrides this lazy val
* to provide custom canonicalization logic.
*/
lazy val preCanonicalized: Expression = {
val canonicalizedChildren = children.map(_.preCanonicalized)
withNewChildren(canonicalizedChildren)
}

/**
* Returns an expression where a best effort attempt has been made to transform `this` in a way
* that preserves the result but removes cosmetic variations (case sensitivity, ordering for
* commutative operations, etc.) See [[Canonicalize]] for more details.
* commutative operations, etc.).
*
* `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
* evaluate to the same result.
*
* The process of canonicalization is a one pass, bottom-up expression tree computation based on
* canonicalizing children before canonicalizing the current node. There is one exception though,
* as the canonicalization of adjacent, same class [[CommutativeExpression]]s happens in a way
* that calling `canonicalized` on the root:
* 1. Gathers and canonicalizes the non-commutative (or commutative but not same class) child
* expressions of the adjacent expressions.
* 2. Reorders the canonicalized child expressions by their hashcode.
* This means that the lazy `canonicalized` is called and computed only on the root of the
* adjacent expressions.
*/
lazy val canonicalized: Expression = Canonicalize.reorderCommutativeOperators(preCanonicalized)
lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
withNewChildren(canonicalizedChildren)
}

/**
* Returns true when two expressions will always compute the same result, even if they differ
Expand Down Expand Up @@ -364,7 +349,7 @@ trait RuntimeReplaceable extends Expression {
// As this expression gets replaced at optimization with its "child" expression,
// two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions
// are semantically equal.
override lazy val preCanonicalized: Expression = replacement.preCanonicalized
override lazy val canonicalized: Expression = replacement.canonicalized

final override def eval(input: InternalRow = null): Any =
throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
Expand Down Expand Up @@ -1176,3 +1161,21 @@ trait ComplexTypeMergingExpression extends Expression {
trait UserDefinedExpression {
def name: String
}

trait CommutativeExpression extends Expression {
  /**
   * Flattens a run of adjacent commutative operators into the list of their operands.
   *
   * Whenever `e` is a [[CommutativeExpression]] on which `f` is defined, the operands returned
   * by `f` are expanded recursively. Any other expression ends the descent and is returned in
   * its canonicalized form.
   */
  private def gatherCommutative(
      e: Expression,
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match {
    case commutative: CommutativeExpression if f.isDefinedAt(commutative) =>
      for {
        operand <- f(commutative)
        leaf <- gatherCommutative(operand, f)
      } yield leaf
    case leaf =>
      leaf.canonicalized :: Nil
  }

  /**
   * Collects the operands of the adjacent commutative operators rooted at this node (such as
   * nested [[And]]s) and sorts them by `hashCode`, so that cosmetic differences in operand
   * order disappear.
   *
   * @param f extracts the operands of the commutative nodes that should be flattened
   * @return the canonicalized operands, ordered by their `hashCode`
   */
  protected def orderCommutative(
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = {
    val operands = gatherCommutative(this, f)
    operands.sortBy(_.hashCode())
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ case class PythonUDF(

override def nullable: Boolean = true

override lazy val preCanonicalized: Expression = {
val canonicalizedChildren = children.map(_.preCanonicalized)
override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
// `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ case class ScalaUDF(

override def name: String = udfName.getOrElse("UDF")

override lazy val preCanonicalized: Expression = {
override lazy val canonicalized: Expression = {
// SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
// need it to identify a `ScalaUDF`.
copy(children = children.map(_.preCanonicalized), inputEncoders = Nil, outputEncoder = None)
copy(children = children.map(_.canonicalized), inputEncoders = Nil, outputEncoder = None)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ case class AggregateExpression(
def filterAttributes: AttributeSet = filter.map(_.references).getOrElse(AttributeSet.empty)

// We compute the same thing regardless of our final result.
override lazy val preCanonicalized: Expression = {
override lazy val canonicalized: Expression = {
val normalizedAggFunc = mode match {
// For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers,
// and the actual children of `aggregateFunction` is not used, here we normalize the expr id.
Expand All @@ -137,10 +137,10 @@ case class AggregateExpression(
}

AggregateExpression(
normalizedAggFunc.preCanonicalized.asInstanceOf[AggregateFunction],
normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
mode,
isDistinct,
filter.map(_.preCanonicalized),
filter.map(_.canonicalized),
ExprId(0))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ object BinaryArithmetic {
case class Add(
left: Expression,
right: Expression,
evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {
evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
with CommutativeExpression {

def this(left: Expression, right: Expression) =
this(left, right, EvalMode.fromSQLConf(SQLConf.get))
Expand Down Expand Up @@ -473,6 +474,11 @@ case class Add(

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add =
copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
  // Only flatten nested `Add`s that share this node's `evalMode`. A nested `Add` with a
  // different mode is not matched by the partial function, so `orderCommutative` treats it as
  // an opaque operand and canonicalizes it on its own, keeping its own `evalMode`. Without the
  // guard the whole chain would be rebuilt with the root's `evalMode`, which could make
  // expressions with different eval modes canonicalize to the same form.
  orderCommutative({ case Add(l, r, mode) if mode == evalMode => Seq(l, r) })
    .reduce(Add(_, _, evalMode))
}
}

@ExpressionDescription(
Expand Down Expand Up @@ -563,7 +569,8 @@ case class Subtract(
case class Multiply(
left: Expression,
right: Expression,
evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {
evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
with CommutativeExpression {

def this(left: Expression, right: Expression) =
this(left, right, EvalMode.fromSQLConf(SQLConf.get))
Expand Down Expand Up @@ -612,6 +619,11 @@ case class Multiply(

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
  // Only flatten nested `Multiply`s that share this node's `evalMode`. A nested `Multiply` with
  // a different mode is not matched by the partial function, so `orderCommutative` treats it as
  // an opaque operand and canonicalizes it on its own, keeping its own `evalMode`. Without the
  // guard the whole chain would be rebuilt with the root's `evalMode`, which could make
  // expressions with different eval modes canonicalize to the same form.
  orderCommutative({ case Multiply(l, r, mode) if mode == evalMode => Seq(l, r) })
    .reduce(Multiply(_, _, evalMode))
}
}

// Common base trait for Divide and Remainder, since these two classes are almost identical
Expand Down Expand Up @@ -1176,7 +1188,8 @@ case class Pmod(
""",
since = "1.5.0",
group = "math_funcs")
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression {
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression
with CommutativeExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand Down Expand Up @@ -1239,6 +1252,10 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Least =
copy(children = newChildren)

override lazy val canonicalized: Expression = {
  // Flatten directly nested `Least` nodes into one operand list; `orderCommutative` returns the
  // canonicalized operands sorted by `hashCode`.
  val orderedOperands = orderCommutative({ case Least(operands) => operands })
  Least(orderedOperands)
}
}

/**
Expand All @@ -1254,7 +1271,8 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression
""",
since = "1.5.0",
group = "math_funcs")
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression {
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression
with CommutativeExpression {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand Down Expand Up @@ -1317,4 +1335,8 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Greatest =
copy(children = newChildren)

override lazy val canonicalized: Expression = {
  // Flatten directly nested `Greatest` nodes into one operand list; `orderCommutative` returns
  // the canonicalized operands sorted by `hashCode`.
  val orderedOperands = orderCommutative({ case Greatest(operands) => operands })
  Greatest(orderedOperands)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ import org.apache.spark.sql.types._
""",
since = "1.4.0",
group = "bitwise_funcs")
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic
with CommutativeExpression {

protected override val evalMode: EvalMode.Value = EvalMode.LEGACY

Expand All @@ -59,6 +60,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
  // Gather the operands of adjacent `BitwiseAnd` nodes in `hashCode` order, then fold them back
  // into a left-nested chain of `BitwiseAnd`s.
  val orderedOperands = orderCommutative({ case BitwiseAnd(lhs, rhs) => Seq(lhs, rhs) })
  orderedOperands.reduce((acc, operand) => BitwiseAnd(acc, operand))
}
}

/**
Expand All @@ -75,7 +80,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
""",
since = "1.4.0",
group = "bitwise_funcs")
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic
with CommutativeExpression {

protected override val evalMode: EvalMode.Value = EvalMode.LEGACY

Expand All @@ -98,6 +104,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
  // Gather the operands of adjacent `BitwiseOr` nodes in `hashCode` order, then fold them back
  // into a left-nested chain of `BitwiseOr`s.
  val orderedOperands = orderCommutative({ case BitwiseOr(lhs, rhs) => Seq(lhs, rhs) })
  orderedOperands.reduce((acc, operand) => BitwiseOr(acc, operand))
}
}

/**
Expand All @@ -114,7 +124,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
""",
since = "1.4.0",
group = "bitwise_funcs")
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic
with CommutativeExpression {

protected override val evalMode: EvalMode.Value = EvalMode.LEGACY

Expand All @@ -137,6 +148,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight)

override lazy val canonicalized: Expression = {
  // Gather the operands of adjacent `BitwiseXor` nodes in `hashCode` order, then fold them back
  // into a left-nested chain of `BitwiseXor`s.
  val orderedOperands = orderCommutative({ case BitwiseXor(lhs, rhs) => Seq(lhs, rhs) })
  orderedOperands.reduce((acc, operand) => BitwiseXor(acc, operand))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]

lazy val childSchema = child.dataType.asInstanceOf[StructType]

override lazy val preCanonicalized: Expression = {
copy(child = child.preCanonicalized, name = None)
override lazy val canonicalized: Expression = {
copy(child = child.canonicalized, name = None)
}

override def dataType: DataType = childSchema(ordinal).dataType
Expand Down
Loading