Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ case class AttributeReference(
def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId

override def equals(other: Any): Boolean = other match {
case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
case ar: AttributeReference =>
name == ar.name && dataType == ar.dataType && nullable == ar.nullable &&
metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gatorsmile , can you send a follow-up PR to also update the hashCode according to this? thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will do it tonight.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
import scala.collection.mutable.ArrayBuffer

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
Expand Down Expand Up @@ -244,12 +244,12 @@ private[sql] object Expand {
*/
private def buildNonSelectExprSet(
bitmask: Int,
exprs: Seq[Expression]): OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)
exprs: Seq[Expression]): ArrayBuffer[Expression] = {
val set = new ArrayBuffer[Expression](2)

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
bit -= 1
}

Expand Down Expand Up @@ -279,7 +279,7 @@ private[sql] object Expand {

(child.output :+ gid).map(expr => expr transformDown {
// TODO this causes a problem when a column is used both for grouping and aggregation.
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,17 +235,17 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
expressions.toSet.subsetOf(requiredClustering.toSet)
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}

override def compatibleWith(other: Partitioning): Boolean = other match {
case o: HashPartitioning => this == o
case o: HashPartitioning => this.semanticEquals(o)
case _ => false
}

override def guarantees(other: Partitioning): Boolean = other match {
case o: HashPartitioning => this == o
case o: HashPartitioning => this.semanticEquals(o)
case _ => false
}

Expand Down Expand Up @@ -276,17 +276,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet)
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}

override def compatibleWith(other: Partitioning): Boolean = other match {
case o: RangePartitioning => this == o
case o: RangePartitioning => this.semanticEquals(o)
case _ => false
}

override def guarantees(other: Partitioning): Boolean = other match {
case o: RangePartitioning => this == o
case o: RangePartitioning => this.semanticEquals(o)
case _ => false
}
}
Expand Down