Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Extended CollationSuite and added tests where SortMergeJoin is forced
  • Loading branch information
vladanvasi-db committed Nov 6, 2024
commit 81b08fc99408819034edfe4af2c40f5903adafcd
167 changes: 162 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("hash join should be used for collated strings") {
test("hash join should be used for collated strings if sort merge join is not forced") {
val t1 = "T_1"
val t2 = "T_2"

Expand Down Expand Up @@ -1598,11 +1598,38 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
case b: HashJoin => b.leftKeys.head
}.head.isInstanceOf[CollationKey])
}

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be easier to run the tests with both the default and -1 values of the conf, and then just assert that different joins are used based on the conf's value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like we would avoid a lot of code duplication with this approach.

Copy link
Contributor

@uros-db uros-db Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like this:

Seq("-1", "1").foreach(val =>
  withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> val) {
    ...

I.e., iterate over the possible values to reduce duplication.
You can also conditionally collect the join plan nodes based on the value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did exactly this: the asserts for SparkPlan nodes are refactored. However, the collationKey check in the plan could not be refactored in the same way as the asserts, so some duplication remains in the code, but it is not significant.

val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)

// Only if collation doesn't support binary equality, collation key should be injected.
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
assert(queryPlan.toString().contains("collationkey"))
} else {
assert(!queryPlan.toString().contains("collationkey"))
}
}
}
})
}

test("hash join should be used for arrays of collated strings") {
test("hash join should be used for arrays of collated strings if sort merge join is not forced") {
val t1 = "T_1"
val t2 = "T_2"

Expand Down Expand Up @@ -1656,11 +1683,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
case b: BroadcastHashJoinExec => b.leftKeys.head
}.head.isInstanceOf[ArrayTransform])
}

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)

// Only if collation doesn't support binary equality, collation key should be injected.
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
assert(queryPlan.toString().contains("collationkey"))
} else {
assert(!queryPlan.toString().contains("collationkey"))
}
}
}
})
}

test("hash join should be used for arrays of arrays of collated strings") {
test("hash join should be used for arrays of arrays of collated strings " +
"if sort merge join is not forced") {
val t1 = "T_1"
val t2 = "T_2"

Expand Down Expand Up @@ -1718,11 +1773,38 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
case b: BroadcastHashJoinExec => b.leftKeys.head
}.head.isInstanceOf[ArrayTransform])
}

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)

// Only if collation doesn't support binary equality, collation key should be injected.
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
assert(queryPlan.toString().contains("collationkey"))
} else {
assert(!queryPlan.toString().contains("collationkey"))
}
}
}
})
}

test("hash join should respect collation for struct of strings") {
test("hash and sort merge join should respect collation for struct of strings") {
val t1 = "T_1"
val t2 = "T_2"

Expand Down Expand Up @@ -1771,11 +1853,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
} else {
assert(!queryPlan.toString().contains("collationkey"))
}

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)

// Only if collation doesn't support binary equality, collation key should be injected.
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
assert(queryPlan.toString().contains("collationkey"))
} else {
assert(!queryPlan.toString().contains("collationkey"))
}
}
}
})
}

test("hash join should respect collation for struct of array of struct of strings") {
test("hash and sort merge join should respect collation " +
"for struct of array of struct of strings") {
val t1 = "T_1"
val t2 = "T_2"

Expand Down Expand Up @@ -1830,6 +1940,33 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
} else {
assert(!queryPlan.toString().contains("collationkey"))
}

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)

// Only if collation doesn't support binary equality, collation key should be injected.
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) {
assert(queryPlan.toString().contains("collationkey"))
} else {
assert(!queryPlan.toString().contains("collationkey"))
}
}
}
})
}
Expand Down Expand Up @@ -1914,6 +2051,26 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
case _: SortMergeJoinExec => ()
}.isEmpty
)

// Disable broadcast join to force sort merge join.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that sort merge join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)
}
}
})
}
Expand Down