Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.{mutable, GenTraversableOnce}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ExpressionSet {
Expand Down Expand Up @@ -55,13 +55,22 @@ object ExpressionSet {
* For non-deterministic expressions, they are always considered as not contained in the [[Set]].
* On adding a non-deterministic expression, simply append it to the original expressions.
* This is consistent with how we define `semanticEquals` between two expressions.
*
* The constructor of this class is protected so caller can only initialize an Expression from
* empty, then build it using `add` and `remove` methods. So every instance of this class holds the
* invariant that:
* 1. Every expr `e` in `baseSet` satisfies `e.deterministic && e.canonicalized == e`
* 2. Every deterministic expr `e` in `originals` satisfies that `e.canonicalized` is already
* accessed.
*/
class ExpressionSet protected(
private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
private val originals: mutable.Buffer[Expression] = new ArrayBuffer)
extends Iterable[Expression] {
private var originals: mutable.Buffer[Expression] = new ArrayBuffer)
extends scala.collection.Set[Expression]
with scala.collection.SetLike[Expression, ExpressionSet] {

// Note: this class supports Scala 2.12. A parallel source tree has a 2.13 implementation.
override def empty: ExpressionSet = new ExpressionSet()

protected def add(e: Expression): Unit = {
if (!e.deterministic) {
Expand All @@ -74,49 +83,37 @@ class ExpressionSet protected(

protected def remove(e: Expression): Unit = {
if (e.deterministic) {
baseSet --= baseSet.filter(_ == e.canonicalized)
originals --= originals.filter(_.canonicalized == e.canonicalized)
baseSet.remove(e.canonicalized)
originals = originals.filter(!_.semanticEquals(e))
Copy link
Contributor Author

@minyyy minyyy Apr 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two changes happening in this line.

  1. Changes to originals = originals.filter(...). This is O(n) as the previous implementation is O(mn).
  2. Uses semanticEquals instead of the previous _.canonicalized == e.canonicalized as the condition. semanticEquals can short circuit and avoid unnecessary access of canonicalized for non-deterministic expressions.

By invariant 2, the current implementation should not evaluate canonicalized at all.

}
}

def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)

override def filter(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filter(e => p(e.canonicalized))
val newBaseSet = baseSet.filter(e => p(e))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By invariant 1: .canonicalized is not needed and can cause performance issue.

val newOriginals = originals.filter(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

override def filterNot(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filterNot(e => p(e.canonicalized))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By invariant 1: .canonicalized is not needed and can cause performance issue.

val newBaseSet = baseSet.filterNot(e => p(e))
val newOriginals = originals.filterNot(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

def +(elem: Expression): ExpressionSet = {
override def +(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.add(elem)
newSet
}

def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
val newSet = clone()
elems.foreach(newSet.add)
newSet
}

def -(elem: Expression): ExpressionSet = {
override def -(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.remove(elem)
newSet
}

def --(elems: GenTraversableOnce[Expression]): ExpressionSet = {
val newSet = clone()
elems.foreach(newSet.remove)
newSet
}

def map(f: Expression => Expression): ExpressionSet = {
val newSet = new ExpressionSet()
this.iterator.foreach(elem => newSet.add(f(elem)))
Expand All @@ -129,21 +126,9 @@ class ExpressionSet protected(
newSet
}

def iterator: Iterator[Expression] = originals.iterator

def union(that: ExpressionSet): ExpressionSet = {
val newSet = clone()
that.iterator.foreach(newSet.add)
newSet
}

def subsetOf(that: ExpressionSet): Boolean = this.iterator.forall(that.contains)

def intersect(that: ExpressionSet): ExpressionSet = this.filter(that.contains)

def diff(that: ExpressionSet): ExpressionSet = this -- that
override def iterator: Iterator[Expression] = originals.iterator

def apply(elem: Expression): Boolean = this.contains(elem)
override def apply(elem: Expression): Boolean = this.contains(elem)

override def equals(obj: Any): Boolean = obj match {
case other: ExpressionSet => this.baseSet == other.baseSet
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.{mutable, IterableFactory, IterableOps}
import scala.collection.mutable.ArrayBuffer

object ExpressionSet {
/** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */
def apply(expressions: IterableOnce[Expression]): ExpressionSet = {
val set = new ExpressionSet()
expressions.iterator.foreach(set.add)
set
}

def apply(): ExpressionSet = {
new ExpressionSet()
}
}

/**
* A [[Set]] where membership is determined based on determinacy and a canonical representation of
* an [[Expression]] (i.e. one that attempts to ignore cosmetic differences).
* See [[Canonicalize]] for more details.
*
* Internally this set uses the canonical representation, but keeps also track of the original
* expressions to ease debugging. Since different expressions can share the same canonical
* representation, this means that operations that extract expressions from this set are only
* guaranteed to see at least one such expression. For example:
*
* {{{
* val set = ExpressionSet(a + 1, 1 + a)
*
* set.iterator => Iterator(a + 1)
* set.contains(a + 1) => true
* set.contains(1 + a) => true
* set.contains(a + 2) => false
* }}}
*
* For non-deterministic expressions, they are always considered as not contained in the [[Set]].
* On adding a non-deterministic expression, simply append it to the original expressions.
* This is consistent with how we define `semanticEquals` between two expressions.
*
* The constructor of this class is protected so caller can only initialize an Expression from
* empty, then build it using `add` and `remove` methods. So every instance of this class holds the
* invariant that:
* 1. Every expr `e` in `baseSet` satisfies `e.deterministic && e.canonicalized == e`
* 2. Every deterministic expr `e` in `originals` satisfies that `e.canonicalized` is already
* accessed.
*/
class ExpressionSet protected(
private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
private var originals: mutable.Buffer[Expression] = new ArrayBuffer)
extends scala.collection.Set[Expression]
with scala.collection.SetOps[Expression, scala.collection.Set, ExpressionSet] {

override protected def fromSpecific(coll: IterableOnce[Expression]): ExpressionSet = {
val set = new ExpressionSet()
coll.iterator.foreach(set.add)
set
}

override protected def newSpecificBuilder: mutable.Builder[Expression, ExpressionSet] =
new mutable.Builder[Expression, ExpressionSet] {
var expr_set: ExpressionSet = new ExpressionSet()
def clear(): Unit = expr_set = new ExpressionSet()
def result(): ExpressionSet = expr_set
def addOne(expr: Expression): this.type = {
expr_set.add(expr)
this
}
}

override def empty: ExpressionSet = new ExpressionSet()

override def diff(that: scala.collection.Set[Expression]): ExpressionSet = this -- that

protected def add(e: Expression): Unit = {
if (!e.deterministic) {
originals += e
} else if (!baseSet.contains(e.canonicalized)) {
baseSet.add(e.canonicalized)
originals += e
}
}

protected def remove(e: Expression): Unit = {
if (e.deterministic) {
baseSet.remove(e.canonicalized)
originals = originals.filter(!_.semanticEquals(e))
}
}

override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)

override def filter(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filter(e => p(e))
val newOriginals = originals.filter(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

override def filterNot(p: Expression => Boolean): ExpressionSet = {
val newBaseSet = baseSet.filterNot(e => p(e))
val newOriginals = originals.filterNot(e => p(e.canonicalized))
new ExpressionSet(newBaseSet, newOriginals)
}

override def +(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.add(elem)
newSet
}

override def -(elem: Expression): ExpressionSet = {
val newSet = clone()
newSet.remove(elem)
newSet
}

def map(f: Expression => Expression): ExpressionSet = {
val newSet = new ExpressionSet()
this.iterator.foreach(elem => newSet.add(f(elem)))
newSet
}

def flatMap(f: Expression => Iterable[Expression]): ExpressionSet = {
val newSet = new ExpressionSet()
this.iterator.foreach(f(_).foreach(newSet.add))
newSet
}

override def iterator: Iterator[Expression] = originals.iterator

override def equals(obj: Any): Boolean = obj match {
case other: ExpressionSet => this.baseSet == other.baseSet
case _ => false
}

override def hashCode(): Int = baseSet.hashCode()

override def clone(): ExpressionSet = new ExpressionSet(baseSet.clone(), originals.clone())

/**
* Returns a string containing both the post [[Canonicalize]] expressions and the original
* expressions in this set.
*/
def toDebugString: String =
s"""
|baseSet: ${baseSet.mkString(", ")}
|originals: ${originals.mkString(", ")}
""".stripMargin
}