Skip to content

Commit 8e51ae5

Browse files
committed
[SPARK-22160][SQL] Allow changing sample points per partition in range shuffle exchange
1 parent d41e347 commit 8e51ae5

File tree

3 files changed: +27 additions, −3 deletions

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -108,9 +108,17 @@ class HashPartitioner(partitions: Int) extends Partitioner {
108108
class RangePartitioner[K : Ordering : ClassTag, V](
109109
partitions: Int,
110110
rdd: RDD[_ <: Product2[K, V]],
111-
private var ascending: Boolean = true)
111+
private var ascending: Boolean = true,
112+
val samplePointsPerPartitionHint: Int = 20)
112113
extends Partitioner {
113114

115+
// A constructor declared in order to maintain backward compatibility for Java, when we add the
116+
// 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
117+
// This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
118+
def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
119+
this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
120+
}
121+
114122
// We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
115123
require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
116124

@@ -122,7 +130,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
122130
Array.empty
123131
} else {
124132
// This is the sample size we need to have roughly balanced output partitions, capped at 1M.
125-
val sampleSize = math.min(20.0 * partitions, 1e6)
133+
// Cast to double to avoid overflowing ints or longs
134+
val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
126135
// Assume the input partitions are roughly balanced and over-sample a little bit.
127136
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
128137
val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -907,6 +907,14 @@ object SQLConf {
907907
.booleanConf
908908
.createWithDefault(false)
909909

910+
val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION =
911+
buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition")
912+
.internal()
913+
.doc("Number of points to sample per partition in order to determine the range boundaries" +
914+
" for range partitioning, typically used in global sorting (without limit).")
915+
.intConf
916+
.createWithDefault(100)
917+
910918
val ARROW_EXECUTION_ENABLE =
911919
buildConf("spark.sql.execution.arrow.enabled")
912920
.internal()
@@ -1199,6 +1207,8 @@ class SQLConf extends Serializable with Logging {
11991207

12001208
def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME)
12011209

1210+
def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
1211+
12021212
def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
12031213

12041214
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
3030
import org.apache.spark.sql.catalyst.plans.physical._
3131
import org.apache.spark.sql.execution._
3232
import org.apache.spark.sql.execution.metric.SQLMetrics
33+
import org.apache.spark.sql.internal.SQLConf
3334
import org.apache.spark.util.MutablePair
3435

3536
/**
@@ -218,7 +219,11 @@ object ShuffleExchangeExec {
218219
iter.map(row => mutablePair.update(row.copy(), null))
219220
}
220221
implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes)
221-
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
222+
new RangePartitioner(
223+
numPartitions,
224+
rddForSampling,
225+
ascending = true,
226+
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
222227
case SinglePartition =>
223228
new Partitioner {
224229
override def numPartitions: Int = 1

Comments (0)