
Commit c4787a3

marmbrus authored and rxin committed
[SPARK-3194][SQL] Add AttributeSet to fix bugs with invalid comparisons of AttributeReferences
It is common to want to describe sets of attributes that are in various parts of a query plan. However, the semantics of putting `AttributeReference` objects into a standard Scala `Set` result in subtle bugs when references differ cosmetically. For example, with case insensitive resolution it is possible to have two references to the same attribute whose names are not equal.

In this PR I introduce a new abstraction, an `AttributeSet`, which performs all comparisons using the globally unique `ExpressionId` instead of case class equality. (There is already a related class, [`AttributeMap`](https://github.com/marmbrus/spark/blob/inMemStats/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala#L32).) This new type of set is used to fix a bug in the optimizer where needed attributes were getting projected away underneath join operators.

I also took this opportunity to refactor the expression and query plan base classes. In all but one instance the logic for computing the `references` of an `Expression` was the same, so I moved that logic into the base class. For query plans the semantics of the `references` method were ill defined (is it the references of the output? those used by expression evaluation? or something else?). As a result, this method wasn't really used very much, so I removed it.

TODO:
- [x] Finish scala doc for `AttributeSet`
- [x] Scan the code for other instances of `Set[Attribute]` and refactor them.
- [x] Finish removing `references` from `QueryPlan`

Author: Michael Armbrust <[email protected]>

Closes #2109 from marmbrus/attributeSets and squashes the following commits:

1c0dae5 [Michael Armbrust] work on serialization bug.
9ba868d [Michael Armbrust] Merge remote-tracking branch 'origin/master' into attributeSets
3ae5288 [Michael Armbrust] review comments
40ce7f6 [Michael Armbrust] style
d577cc7 [Michael Armbrust] Scaladoc
cae5d22 [Michael Armbrust] remove more references implementations
d6e16be [Michael Armbrust] Remove more instances of "def references" and normal sets of attributes.
fc26b49 [Michael Armbrust] Add AttributeSet class, remove references from Expression.
1 parent 1208f72 commit c4787a3
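To make the cosmetic-equality problem concrete, here is a small sketch (not part of the patch; attribute names and types are illustrative, and it assumes the Catalyst API of this era, e.g. `org.apache.spark.sql.catalyst.types.IntegerType` and `NamedExpression.newExprId`):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

// Two references to the same underlying attribute that differ only in name casing,
// as can happen with case-insensitive resolution.
val id    = NamedExpression.newExprId
val lower = AttributeReference("a", IntegerType, nullable = true)(exprId = id)
val upper = AttributeReference("A", IntegerType, nullable = true)(exprId = id)

// Case-class equality also compares the name, so a plain Scala Set misses the match.
Set[Attribute](lower).contains(upper)        // false

// AttributeSet compares only the expression id, so the cosmetic difference is ignored.
AttributeSet(lower :: Nil).contains(upper)   // true
```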

35 files changed: +166 −123 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 3 additions & 3 deletions
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
       case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
         val resolved = unresolved.flatMap(child.resolveChildren)
-        val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+        val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })

         val missingInProject = requiredAttributes -- p.output
         if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
         )

         logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
-        val missingInAggs = resolved -- a.outputSet
+        val resolved = unresolved.flatMap(groupingRelation.resolve)
+        val missingInAggs = resolved.filterNot(a.outputSet.contains)
         logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
         if (missingInAggs.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
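For the first hunk, the practical effect of switching `requiredAttributes` to an `AttributeSet` is that the `-- p.output` subtraction now works by expression id, so an ORDER BY reference that differs only cosmetically from an attribute the Project already emits is no longer reported as missing. A rough sketch (hypothetical attributes, not code from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

val projected = AttributeReference("col", IntegerType)()                  // emitted by the Project
val required  = AttributeReference("COL", IntegerType)(projected.exprId)  // resolved from the ORDER BY

// Old behaviour: plain Set difference keeps `required`, so it looks missing.
Set[Attribute](required) -- Seq(projected)       // non-empty

// New behaviour: AttributeSet subtracts by expression id, so nothing is missing.
AttributeSet(required :: Nil) -- Seq(projected)  // empty
```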

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 0 additions & 1 deletion
@@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
   override def dataType = throw new UnresolvedException(this, "dataType")
   override def foldable = throw new UnresolvedException(this, "foldable")
   override def nullable = throw new UnresolvedException(this, "nullable")
-  override def references = children.flatMap(_.references).toSet
   override lazy val resolved = false

   // Unresolved functions are transient at compile time and don't get evaluated during execution.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

protected class AttributeEquals(val a: Attribute) {
  override def hashCode() = a.exprId.hashCode()
  override def equals(other: Any) = other match {
    case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
    case otherAttribute => false
  }
}

object AttributeSet {
  /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
  def apply(baseSet: Seq[Attribute]) = {
    new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
  }
}

/**
 * A Set designed to hold [[AttributeReference]] objects, that performs equality checking using
 * expression id instead of standard java equality. Using expression id means that these
 * sets will correctly test for membership, even when the AttributeReferences in question differ
 * cosmetically (e.g., the names have different capitalizations).
 *
 * Note that we do not override equality for Attribute references as it is really weird when
 * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests,
 * and also makes doing transformations hard (we always try to keep older trees instead of new
 * ones when the transformation was a no-op).
 */
class AttributeSet private (val baseSet: Set[AttributeEquals])
  extends Traversable[Attribute] with Serializable {

  /** Returns true if the members of this AttributeSet and other are the same. */
  override def equals(other: Any) = other match {
    case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
    case _ => false
  }

  /** Returns true if this set contains an Attribute with the same expression id as `elem` */
  def contains(elem: NamedExpression): Boolean =
    baseSet.contains(new AttributeEquals(elem.toAttribute))

  /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */
  def +(elem: Attribute): AttributeSet =  // scalastyle:ignore
    new AttributeSet(baseSet + new AttributeEquals(elem))

  /** Returns a new [[AttributeSet]] that does not contain `elem`. */
  def -(elem: Attribute): AttributeSet =
    new AttributeSet(baseSet - new AttributeEquals(elem))

  /** Returns an iterator containing all of the attributes in the set. */
  def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator

  /**
   * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
   * `other`.
   */
  def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)

  /**
   * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
   * in `other`.
   */
  def --(other: Traversable[NamedExpression]) =
    new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))

  /**
   * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
   * in `other`.
   */
  def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)

  /**
   * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] where `f`
   * evaluates to true.
   */
  override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))

  /**
   * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
   * `this` and `other`.
   */
  def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))

  override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)

  // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
  // sorts of things in its closure.
  override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
}
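A quick illustration of how the new class behaves (a sketch, not part of the diff; attribute names and types are made up):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

val id1 = NamedExpression.newExprId
val id2 = NamedExpression.newExprId

// Same expression ids, cosmetically different names.
val s1 = AttributeSet(Seq(
  AttributeReference("a", IntegerType)(id1),
  AttributeReference("b", IntegerType)(id2)))
val s2 = AttributeSet(Seq(
  AttributeReference("A", IntegerType)(id1),
  AttributeReference("B", IntegerType)(id2)))

s1 == s2             // true: comparisons use only the expression ids
s1.subsetOf(s2)      // true
(s1 -- s2).isEmpty   // true: set algebra also keys on expression ids
```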

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 0 additions & 2 deletions
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

   type EvaluatedType = Any

-  override def references = Set.empty
-
   override def toString = s"input[$ordinal]"

   override def eval(input: Row): Any = input(ordinal)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 2 additions & 4 deletions
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def foldable: Boolean = false
   def nullable: Boolean
-  def references: Set[Attribute]
+  def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))

   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {

   override def foldable = left.foldable && right.foldable

-  override def references = left.references ++ right.references
-
   override def toString = s"($left $symbol $right)"
 }

@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
 abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
   self: Product =>

-  override def references = child.references
+
 }
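With the default moved into `Expression`, composite expressions no longer need their own `references` implementation: the base class unions the references of the children, and only leaf attributes contribute themselves (elsewhere in this patch, `Attribute` presumably overrides `references` to return a set containing just itself; that file is not shown in this excerpt). A hedged sketch of the resulting behaviour:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// Add inherits the base-class default: the union of its children's references.
val expr = Add(a, Add(b, Literal(1)))
expr.references.contains(a)   // true
expr.references.contains(b)   // true
```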

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
 case object Rand extends LeafExpression {
   override def dataType = DoubleType
   override def nullable = false
-  override def references = Set.empty

   private[this] lazy val rand = new Random

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression {

   type EvaluatedType = Any

-  def references = children.flatMap(_.references).toSet
   def nullable = true

   /** This method has been generated by this script

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
 case class SortOrder(child: Expression, direction: SortDirection) extends Expression
   with trees.UnaryNode[Expression] {

-  override def references = child.references
   override def dataType = child.dataType
   override def nullable = child.nullable

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
   type EvaluatedType = DynamicRow

   def nullable = false
-  def references = children.toSet
+
   def dataType = DynamicType

   override def eval(input: Row): DynamicRow = input match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 11 additions & 14 deletions
@@ -78,7 +78,7 @@ abstract class AggregateFunction

   /** Base should return the generic aggregate expression that this function is computing */
   val base: AggregateExpression
-  override def references = base.references
+
   override def nullable = base.nullable
   override def dataType = base.dataType

@@ -89,7 +89,7 @@ abstract class AggregateFunction
 }

 case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
 }

 case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
 }

 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
   def this() = this(null)

   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
   def this() = this(null)

   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = ArrayType(expressions.head.dataType)
   override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
   def this() = this(null)

   override def children = inputSet :: Nil
-  override def references = inputSet.references
   override def nullable = false
   override def dataType = LongType
   override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(

 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)

 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)

 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
   extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
 }

 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = DoubleType
   override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 }

 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {

-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
 }

 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"FIRST($child)"
