Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Allow custom plugin for AQE cost evaluator
  • Loading branch information
c21 committed Jul 2, 2021
commit 494b8bc1198214d4e0e72ec188b1e32e4ddf8e4e
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val ADAPTIVE_COST_EVALUATOR_CLASS =
buildConf("spark.sql.adaptive.costEvaluatorClass")
.version("3.2.0")
Copy link
Member

@HyukjinKwon HyukjinKwon Jul 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the only think is that the version has to be 3.3.0 since we cut the branch now. Since this PR won't likely affect anything in the main code, I am okay with merging to 3.2.0 either tho. I will leave it to @cloud-fan and you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.2 is the first version that enables AQE by default, and this seems to be a useful extension. Let's include it in 3.2.

.internal()
.stringConf
.createWithDefault("org.apache.spark.sql.execution.adaptive.SimpleCostEvaluator")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make it an optional conf: spark.sql.adaptive.customCostEvaluatorClass. If not set, we use the builtin impl.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - sure, updated.


val SUBEXPRESSION_ELIMINATION_ENABLED =
buildConf("spark.sql.subexpressionElimination.enabled")
.internal()
Expand Down Expand Up @@ -3582,6 +3589,8 @@ class SQLConf extends Serializable with Logging {

def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED)

def adaptiveCostEvaluatorClass: String = getConf(ADAPTIVE_COST_EVALUATOR_CLASS)

def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)

def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ case class AdaptiveSparkPlanExec(
}
}

@transient private val costEvaluator = SimpleCostEvaluator
@transient private val costEvaluator = CostEvaluator.instantiate(conf.adaptiveCostEvaluatorClass)

@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,32 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.Utils

/**
* Represents the cost of a plan.
* An interface to represent the cost of a plan.
*/
trait Cost extends Ordered[Cost]

/**
* Evaluates the cost of a physical plan.
* An interface to evaluate the cost of a physical plan.
*/
trait CostEvaluator {
def evaluateCost(plan: SparkPlan): Cost
}

object CostEvaluator extends Logging {

/**
* Instantiates a [[CostEvaluator]] using the given className.
*/
def instantiate(className: String): CostEvaluator = {
logDebug(s"Creating CostEvaluator $className")
val clazz = Utils.classForName[CostEvaluator](className)
Copy link
Contributor

@cloud-fan cloud-fan Jul 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use the standard API in Spark: Utils.loadExtensions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - good call, updated.

// Use the default no-argument constructor.
val ctor = clazz.getDeclaredConstructor()
ctor.newInstance()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ case class SimpleCost(value: Long) extends Cost {
* A simple implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeLike]] nodes in the plan.
*/
object SimpleCostEvaluator extends CostEvaluator {
case class SimpleCostEvaluator() extends CostEvaluator {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - yeah, updated.


override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1898,4 +1898,53 @@ class AdaptiveQueryExecSuite
assert(coalesceReader.head.partitionSpecs.length == 1)
}
}

test("SPARK-35794: Allow custom plugin for cost evaluator") {
CostEvaluator.instantiate(classOf[SimpleShuffleSortCostEvaluator].getCanonicalName)
CostEvaluator.instantiate(classOf[SimpleCostEvaluator].getCanonicalName)
intercept[ClassCastException] {
CostEvaluator.instantiate(classOf[InvalidCostEvaluator].getCanonicalName)
}

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'"

withSQLConf(SQLConf.ADAPTIVE_COST_EVALUATOR_CLASS.key ->
"org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this custom cost evaluator change the query plan? It seems to be the same with the builtin cost evaluator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - this evaluator does not change plan, and to be the same with the builtin evaluator for this query. Do we want to come up a different one here? I think this just validates the custom evaluator works.

Copy link
Contributor

@cloud-fan cloud-fan Jul 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM, let's leave it then

val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan)
}

withSQLConf(SQLConf.ADAPTIVE_COST_EVALUATOR_CLASS.key ->
"org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") {
intercept[ClassCastException] {
runAdaptiveAndVerifyResult(query)
}
}
}
}
}

/**
* Invalid implementation class for [[CostEvaluator]].
*/
private class InvalidCostEvaluator() {}

/**
* A simple [[CostEvaluator]] to count number of [[ShuffleExchangeLike]] and [[SortExec]].
*/
private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator {
override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
case s: ShuffleExchangeLike => s
case s: SortExec => s
}.size
SimpleCost(cost)
}
}