Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])

// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
override def toSeq: Seq[Attribute] = {
// We need to keep a deterministic output order for `baseSet` because this affects a variable
// order in generated code (e.g., `GenerateColumnAccessor`).
// See SPARK-18394 for details.
baseSet.map(_.a).toSeq.sortBy { a => (a.name, a.exprId.id) }
}

override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,44 @@ class AttributeSetSuite extends SparkFunSuite {
assert(aSet == aSet)
assert(aSet == AttributeSet(aUpper :: Nil))
}

test("SPARK-18394 keep a deterministic output order along with attribute names and exprIds") {
// Checks a simple case
val attrSeqA = {
val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(1098))
val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(107))
val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(838))
val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil)

val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(389))
val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(89329))

val attrSetB = AttributeSet(attr4 :: attr5 :: Nil)
(attrSetA ++ attrSetB).toSeq.map(_.name)
}

val attrSeqB = {
val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(392))
val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(92))
val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(87))
val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil)

val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(9023920))
val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(522))
val attrSetB = AttributeSet(attr4 :: attr5 :: Nil)

(attrSetA ++ attrSetB).toSeq.map(_.name)
}

assert(attrSeqA === attrSeqB)

// Checks the same column names having different exprIds
val attr1 = AttributeReference("c", IntegerType)(exprId = ExprId(1098))
val attr2 = AttributeReference("c", IntegerType)(exprId = ExprId(107))
val attrSetA = AttributeSet(attr1 :: attr2 :: Nil)
val attr3 = AttributeReference("c", IntegerType)(exprId = ExprId(389))
val attrSetB = AttributeSet(attr3 :: Nil)

assert((attrSetA ++ attrSetB).toSeq === attr2 :: attr3 :: attr1 :: Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,13 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
}.head

assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch")
assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch")

// Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method
// in `SparkPlanner` that internally uses `AttributeSet.toSeq`.
// Since we change an output order of `AttributeSet.toSeq` in SPARK-18394,
// we need to sort column names for a test below.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about?

Scanned columns in HiveTableScanExec are generated by the pruneFilterProject method in SparkPlanner. This method internally uses AttributeSet.toSeq, in which the returned output columns are sorted by the names and expression ids.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

look good, I'll update soon.

assert(actualScannedColumns.sorted === expectedScannedColumns.sorted,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment to explain where we call AttributeSet.toSeq?

"Scanned columns mismatch")

val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted
val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted
Expand Down