Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address all comments and rebase to latest master
  • Loading branch information
c21 committed Jul 2, 2021
commit 9fd1bbe693fa2e05c154e4b960e6ebd8ace4fc90
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val ADAPTIVE_COST_EVALUATOR_CLASS =
buildConf("spark.sql.adaptive.costEvaluatorClass")
val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS =
buildConf("spark.sql.adaptive.customCostEvaluatorClass")
.doc("The custom cost evaluator class to be used for adaptive execution. If not being set," +
" Spark will use its own SimpleCostEvaluator by default.")
.version("3.2.0")
Copy link
Member

@HyukjinKwon HyukjinKwon Jul 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing is that the version has to be 3.3.0 since we have cut the branch now. Since this PR is unlikely to affect anything in the main code, I am okay with merging it to 3.2.0 as well, though. I will leave it to @cloud-fan and you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.2 is the first version that enables AQE by default, and this seems to be a useful extension. Let's include it in 3.2.

.internal()
.stringConf
.createWithDefault("org.apache.spark.sql.execution.adaptive.SimpleCostEvaluator")
.createOptional

val SUBEXPRESSION_ELIMINATION_ENABLED =
buildConf("spark.sql.subexpressionElimination.enabled")
Expand Down Expand Up @@ -3589,7 +3591,8 @@ class SQLConf extends Serializable with Logging {

def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED)

def adaptiveCostEvaluatorClass: String = getConf(ADAPTIVE_COST_EVALUATOR_CLASS)
def adaptiveCustomCostEvaluatorClass: Option[String] =
getConf(ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS)

def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ case class AdaptiveSparkPlanExec(
}
}

@transient private val costEvaluator = CostEvaluator.instantiate(conf.adaptiveCostEvaluatorClass)
@transient private val costEvaluator = conf.adaptiveCustomCostEvaluatorClass match {
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
case _ => SimpleCostEvaluator
}

@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
Expand All @@ -38,11 +40,12 @@ object CostEvaluator extends Logging {
/**
* Instantiates a [[CostEvaluator]] using the given className.
*/
def instantiate(className: String): CostEvaluator = {
def instantiate(className: String, conf: SparkConf): CostEvaluator = {
logDebug(s"Creating CostEvaluator $className")
val clazz = Utils.classForName[CostEvaluator](className)
// Use the default no-argument constructor.
val ctor = clazz.getDeclaredConstructor()
ctor.newInstance()
val evaluators = Utils.loadExtensions(classOf[CostEvaluator], Seq(className), conf)
require(evaluators.nonEmpty, "A valid AQE cost evaluator must be specified by config " +
s"${SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key}, but $className resulted in zero " +
"valid evaluator.")
evaluators.head
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ case class SimpleCost(value: Long) extends Cost {
* A simple implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeLike]] nodes in the plan.
*/
case class SimpleCostEvaluator() extends CostEvaluator {
object SimpleCostEvaluator extends CostEvaluator {

override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1900,18 +1900,19 @@ class AdaptiveQueryExecSuite
}

test("SPARK-35794: Allow custom plugin for cost evaluator") {
CostEvaluator.instantiate(classOf[SimpleShuffleSortCostEvaluator].getCanonicalName)
CostEvaluator.instantiate(classOf[SimpleCostEvaluator].getCanonicalName)
intercept[ClassCastException] {
CostEvaluator.instantiate(classOf[InvalidCostEvaluator].getCanonicalName)
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
intercept[IllegalArgumentException] {
CostEvaluator.instantiate(
classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
}

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'"

withSQLConf(SQLConf.ADAPTIVE_COST_EVALUATOR_CLASS.key ->
withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
"org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this custom cost evaluator change the query plan? It seems to be the same as with the builtin cost evaluator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - this evaluator does not change the plan; the result is the same as with the builtin evaluator for this query. Do we want to come up with a different one here? I think this just validates that the custom evaluator works.

Copy link
Contributor

@cloud-fan cloud-fan Jul 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM, let's leave it then

val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
val smj = findTopLevelSortMergeJoin(plan)
Expand All @@ -1921,9 +1922,9 @@ class AdaptiveQueryExecSuite
checkNumLocalShuffleReaders(adaptivePlan)
}

withSQLConf(SQLConf.ADAPTIVE_COST_EVALUATOR_CLASS.key ->
withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
"org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") {
intercept[ClassCastException] {
intercept[IllegalArgumentException] {
runAdaptiveAndVerifyResult(query)
}
}
Expand Down