diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 59a310d1e4f5..0f4d1ae98220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -678,6 +678,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS = + buildConf("spark.sql.adaptive.customCostEvaluatorClass") + .doc("The custom cost evaluator class to be used for adaptive execution. If not being set," + + " Spark will use its own SimpleCostEvaluator by default.") + .version("3.2.0") + .stringConf + .createOptional + val SUBEXPRESSION_ELIMINATION_ENABLED = buildConf("spark.sql.subexpressionElimination.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index cbf70e37ce96..18aaf5b669ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -130,7 +130,11 @@ case class AdaptiveSparkPlanExec( } } - @transient private val costEvaluator = SimpleCostEvaluator + @transient private val costEvaluator = + conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match { + case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf) + case _ => SimpleCostEvaluator + } @transient val initialPlan = context.session.withActive { applyPhysicalRules( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/costing.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/costing.scala index 293e619fffb9..56f29b966d96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/costing.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/costing.scala @@ -17,16 +17,42 @@ package org.apache.spark.sql.execution.adaptive +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Unstable +import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils /** - * Represents the cost of a plan. + * An interface to represent the cost of a plan. + * + * @note This class is subject to be changed and/or moved in the near future. */ +@Unstable trait Cost extends Ordered[Cost] /** - * Evaluates the cost of a physical plan. + * An interface to evaluate the cost of a physical plan. + * + * @note This class is subject to be changed and/or moved in the near future. */ +@Unstable trait CostEvaluator { def evaluateCost(plan: SparkPlan): Cost } + +object CostEvaluator extends Logging { + + /** + * Instantiates a [[CostEvaluator]] using the given className. + */ + def instantiate(className: String, conf: SparkConf): CostEvaluator = { + logDebug(s"Creating CostEvaluator $className") + val evaluators = Utils.loadExtensions(classOf[CostEvaluator], Seq(className), conf) + require(evaluators.nonEmpty, "A valid AQE cost evaluator must be specified by config " + + s"${SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key}, but $className resulted in zero " + + "valid evaluator.") + evaluators.head + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index dac718e7d602..b46cc9f42746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1898,4 +1898,54 @@ class AdaptiveQueryExecSuite assert(coalesceReader.head.partitionSpecs.length == 1) } } + + test("SPARK-35794: Allow custom plugin for cost evaluator") { + CostEvaluator.instantiate( + classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + intercept[IllegalArgumentException] { + CostEvaluator.instantiate( + classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + } + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'" + + withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> + "org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + } + + withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> + "org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") { + intercept[IllegalArgumentException] { + runAdaptiveAndVerifyResult(query) + } + } + } + } +} + +/** + * Invalid implementation class for [[CostEvaluator]]. + */ +private class InvalidCostEvaluator() {} + +/** + * A simple [[CostEvaluator]] to count number of [[ShuffleExchangeLike]] and [[SortExec]]. + */ +private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator { + override def evaluateCost(plan: SparkPlan): Cost = { + val cost = plan.collect { + case s: ShuffleExchangeLike => s + case s: SortExec => s + }.size + SimpleCost(cost) + } }