From 81b08fc99408819034edfe4af2c40f5903adafcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Wed, 6 Nov 2024 16:03:58 +0100 Subject: [PATCH 1/4] Extended CollationSuite and added tests where SortMergeJoin is forced --- .../org/apache/spark/sql/CollationSuite.scala | 167 +++++++++++++++++- 1 file changed, 162 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index d69ba77a1475..42b772d7eba2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1549,7 +1549,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("hash join should be used for collated strings") { + test("hash join should be used for collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1598,11 +1598,38 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) } + + // Disable broadcast join to force sort merge join. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } } }) } - test("hash join should be used for arrays of collated strings") { + test("hash join should be used for arrays of collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1656,11 +1683,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.isInstanceOf[ArrayTransform]) } + + // Disable broadcast join to force sort merge join. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + + // Only if collation doesn't support binary equality, collation key should be injected. 
+ if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } } }) } - test("hash join should be used for arrays of arrays of collated strings") { + test("hash join should be used for arrays of arrays of collated strings " + + "if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1718,11 +1773,38 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.isInstanceOf[ArrayTransform]) } + + // Disable broadcast join to force sort merge join. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } } }) } - test("hash join should respect collation for struct of strings") { + test("hash and sort merge join should respect collation for struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1771,11 +1853,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } else { assert(!queryPlan.toString().contains("collationkey")) } + + // Disable broadcast join to force sort merge join. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } } }) } - test("hash join should respect collation for struct of array of struct of strings") { + test("hash and sort merge join should respect collation " + + "for struct of array of struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1830,6 +1940,33 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } else { assert(!queryPlan.toString().contains("collationkey")) } + + // Disable broadcast join to force sort merge join. 
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } } }) } @@ -1914,6 +2051,26 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { case _: SortMergeJoinExec => () }.isEmpty ) + + // Disable broadcast join to force sort merge join. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") + checkAnswer(df, t.result) + + val queryPlan = df.queryExecution.executedPlan + + // confirm that sort merge join is used instead of hash join + assert( + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty + ) + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + } } }) } From 265efcd44e6b7fca1a674c4e706f2344a516393e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Thu, 7 Nov 2024 14:00:16 +0100 Subject: [PATCH 2/4] Refactored test suite to reduce code --- .../org/apache/spark/sql/CollationSuite.scala | 465 +++++++----------- 1 file changed, 175 insertions(+), 290 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 42b772d7eba2..2260460b07c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -43,6 +44,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val collationNonPreservingSources = Seq("orc", "csv", "json", "text") private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources + @inline + private def isSortMergeForced: Boolean = { + SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1 + } + + private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = { + assert( + // If sort merge join is forced, we should not see HashJoin in the plan. + isSortMergeForced || + // If sort merge join is not forced, we should see HashJoin in the plan + // and not SortMergeJoin. 
+ collectFirst(queryPlan) { + case _: HashJoin => () + }.nonEmpty && + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.isEmpty + ) + + assert( + // If sort merge join is not forced, we should not see SortMergeJoin in the plan. + !isSortMergeForced || + // If sort merge join is forced, we should see SortMergeJoin in the plan + // and not HashJoin. + collectFirst(queryPlan) { + case _: HashJoin => () + }.isEmpty && + collectFirst(queryPlan) { + case _: SortMergeJoinExec => () + }.nonEmpty + ) + } + test("collate returns proper type") { Seq( "utf8_binary", @@ -1562,71 +1596,49 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2), Row("aa", 1, "aa", 2))) ) - - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") - - sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") - - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") - val queryPlan = df.queryExecution.executedPlan + sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) - - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) - } - - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + if (isSortMergeForced) { + // Only if collation doesn't support binary equality, collation key should be injected. 
+ if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } } } } - }) + } } test("hash join should be used for arrays of collated strings if sort merge join is not forced") { @@ -1647,71 +1659,50 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq("aa"), 1, Seq("AA "), 2), Row(Seq("aa"), 1, Seq("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) - - val queryPlan = df.queryExecution.executedPlan - - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) - - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. - function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) - } - - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + if (isSortMergeForced) { + // Only if collation doesn't support binary equality, collation key should be injected. 
+ if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. + function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } } } } - }) + } } test("hash join should be used for arrays of arrays of collated strings " + @@ -1733,75 +1724,54 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA ")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") - - sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + - s" (array(array('${t.data1}')), 2)") - - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) - - val queryPlan = df.queryExecution.executedPlan - - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) - } + sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + + s" (array(array('${t.data1}')), 2)") - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + if (isSortMergeForced) { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function. + asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. + asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } } } } - }) + } } test("hash and sort merge join should respect collation for struct of strings") { @@ -1821,57 +1791,26 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row(Row("aa"), 1, Row("AA "), 2), Row(Row("aa"), 1, Row("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + - s" (named_struct('f', '${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + + s" (named_struct('f', '${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) - - val queryPlan = df.queryExecution.executedPlan - - // Confirm that hash join is used instead of sort merge join. - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) - - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } - - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. 
+ checkRightTypeOfJoinUsed(queryPlan) // Only if collation doesn't support binary equality, collation key should be injected. if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { @@ -1881,7 +1820,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } } - }) + } } test("hash and sort merge join should respect collation " + @@ -1905,60 +1844,30 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA "))), 2), Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))) ) - testCases.foreach(t => { - withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT>>, " + - s"i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', '${t.data1}'))), 1)" - ) - - sql(s"CREATE TABLE $t2 (y STRUCT>>, " + - s"j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', '${t.data2}'))), 2)" - + s", (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) - - val queryPlan = df.queryExecution.executedPlan - - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { + withTable(t1, t2) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT>>, " + + s"i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data1}'))), 1)") - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + sql(s"CREATE TABLE $t2 (y STRUCT>>, " + + s"j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data2}'))), 2), (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) // Only if collation doesn't support binary equality, collation key should be injected. 
if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { @@ -1968,7 +1877,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } } - }) + } } test("rewrite with collationkey should be an excludable rule") { @@ -2028,51 +1937,27 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "'a', 'a', 1", "'A', 'A ', 1", Row("a", "a", 1, "A", "A ", 1)) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (${t.data1})") - sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (${t.data2})") - - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") - checkAnswer(df, t.result) - - val queryPlan = df.queryExecution.executedPlan + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (${t.data1})") + sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (${t.data2})") - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) - - // Disable broadcast join to force sort merge join. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") checkAnswer(df, t.result) val queryPlan = df.queryExecution.executedPlan - // confirm that sort merge join is used instead of hash join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) } } - }) + } } test("hll sketch aggregate should respect collation") { From 9851790db3e24a9ef44e99dfebb1b1955ffa6a2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Fri, 8 Nov 2024 11:26:12 +0100 Subject: [PATCH 3/4] Simplified assertions for SortMerge and HashJoin --- .../org/apache/spark/sql/CollationSuite.scala | 45 ++++++------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2260460b07c3..83d5a452c937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -51,29 +51,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = { assert( - // If sort merge join is forced, we should not see HashJoin in the plan. - isSortMergeForced || - // If sort merge join is not forced, we should see HashJoin in the plan - // and not SortMergeJoin. - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty && - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) - - assert( - // If sort merge join is not forced, we should not see SortMergeJoin in the plan. 
- !isSortMergeForced || - // If sort merge join is forced, we should see SortMergeJoin in the plan - // and not HashJoin. - collectFirst(queryPlan) { - case _: HashJoin => () - }.isEmpty && - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.nonEmpty + collectFirst(queryPlan) { + case _: SortMergeJoinExec => assert(isSortMergeForced) + case _: HashJoin => assert(!isSortMergeForced) + }.nonEmpty ) } @@ -1413,7 +1394,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).isUtf8BinaryType + CollationFactory.fetchCollation(collation).supportsBinaryEquality test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1618,7 +1599,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { if (isSortMergeForced) { // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1626,7 +1607,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } else { // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1681,7 +1662,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { if (isSortMergeForced) { // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1689,7 +1670,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } else { // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. @@ -1749,7 +1730,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { if (isSortMergeForced) { // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1757,7 +1738,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } else { // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. @@ -1813,7 +1794,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkRightTypeOfJoinUsed(queryPlan) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1870,7 +1851,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkRightTypeOfJoinUsed(queryPlan) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) From ba3e7024fc898478808af0e0237a103c2a274cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Tue, 12 Nov 2024 12:23:43 +0100 Subject: [PATCH 4/4] Moved collation key check into common helper function --- .../org/apache/spark/sql/CollationSuite.scala | 49 +++++++------------ 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 83d5a452c937..5024bd19cb13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -58,6 +58,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) } + private def checkCollationKeyInQueryPlan(queryPlan: SparkPlan, collationName: String): Unit = { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(collationName).supportsBinaryEquality) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + test("collate returns proper type") { Seq( "utf8_binary", @@ -1598,12 +1607,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkRightTypeOfJoinUsed(queryPlan) if (isSortMergeForced) { - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } else { // Only if collation doesn't support binary equality, collation key should be injected. @@ -1661,12 +1666,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkRightTypeOfJoinUsed(queryPlan) if (isSortMergeForced) { - // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } else { // Only if collation doesn't support binary equality, collation key should be injected. @@ -1729,12 +1730,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkRightTypeOfJoinUsed(queryPlan) if (isSortMergeForced) { - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } else { // Only if collation doesn't support binary equality, collation key should be injected. @@ -1793,12 +1790,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // confirm that right kind of join is used. checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } } } @@ -1850,12 +1843,8 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // confirm that right kind of join is used. checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) - } + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } } }
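
Note: below is a condensed sketch of the helper pattern the series converges on after patch 4, pieced together from the hunks above. The surrounding CollationSuite scaffolding (SQLConf, collectFirst from AdaptiveSparkPlanHelper, withTable/withSQLConf, the HashJoinTestCase fixtures and table names t1/t2) is assumed, and the hash-join branches in the actual tests additionally inspect the join keys directly (CollationKey / ArrayTransform) per data type; only the common path is shown here.

    @inline
    private def isSortMergeForced: Boolean = {
      SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1
    }

    // The first join node found in the plan must match the forced mode:
    // SortMergeJoinExec when broadcast joins are disabled, HashJoin otherwise.
    private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = {
      assert(
        collectFirst(queryPlan) {
          case _: SortMergeJoinExec => assert(isSortMergeForced)
          case _: HashJoin => assert(!isSortMergeForced)
        }.nonEmpty
      )
    }

    // A collation key is injected only for collations without binary equality.
    private def checkCollationKeyInQueryPlan(queryPlan: SparkPlan, collationName: String): Unit = {
      if (!CollationFactory.fetchCollation(collationName).supportsBinaryEquality) {
        assert(queryPlan.toString().contains("collationkey"))
      } else {
        assert(!queryPlan.toString().contains("collationkey"))
      }
    }

    // Each test case runs twice: once with the broadcast threshold set to -1
    // (forcing sort merge join) and once with the default threshold (hash join).
    for {
      t <- testCases
      broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
    } {
      withTable(t1, t2) {
        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) {
          // ... create and populate t1 and t2 for the test case ...
          val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
          checkAnswer(df, t.result)

          val queryPlan = df.queryExecution.executedPlan
          checkRightTypeOfJoinUsed(queryPlan)
          checkCollationKeyInQueryPlan(queryPlan, t.collation)
        }
      }
    }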