diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index b0fae3fd9443..877ca7f4a9cb 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -487,6 +487,7 @@ private[spark] object LogKeys {
   case object NUM_DRIVERS extends LogKey
   case object NUM_DROPPED_PARTITIONS extends LogKey
   case object NUM_EFFECTIVE_RULE_OF_RUNS extends LogKey
+  case object NUM_ELEMENTS_SPILL_RECORDS extends LogKey
   case object NUM_ELEMENTS_SPILL_THRESHOLD extends LogKey
   case object NUM_EVENTS extends LogKey
   case object NUM_EXAMPLES extends LogKey
@@ -768,6 +769,8 @@ private[spark] object LogKeys {
   case object SPARK_REPO_URL extends LogKey
   case object SPARK_REVISION extends LogKey
   case object SPARK_VERSION extends LogKey
+  case object SPILL_RECORDS_SIZE extends LogKey
+  case object SPILL_RECORDS_SIZE_THRESHOLD extends LogKey
   case object SPILL_TIMES extends LogKey
   case object SQL_TEXT extends LogKey
   case object SRC_PATH extends LogKey
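The three new keys feed the structured-logging calls added later in this patch. A minimal Scala-side sketch, assuming a class that mixes in org.apache.spark.internal.Logging, of how such keys attach values to a log line (the Java hunks below use the equivalent MDC.of form; the class and method names here are hypothetical):

    import org.apache.spark.internal.{Logging, LogKeys, MDC}

    class SpillReporter extends Logging {
      def report(size: Long, threshold: Long): Unit = {
        // Each MDC pairs one of the LogKeys defined above with a runtime value.
        logInfo(log"Spill records size ${MDC(LogKeys.SPILL_RECORDS_SIZE, size)} " +
          log"crossed threshold ${MDC(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, threshold)}")
      }
    }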
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index de3c41a4b526..7502df9e16a8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -89,6 +89,11 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
    */
   private final int numElementsForSpillThreshold;
 
+  /**
+   * Force this sorter to spill when the size in memory is beyond this threshold.
+   */
+  private final long recordsSizeForSpillThreshold;
+
   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
 
@@ -112,6 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long pageCursor = -1;
+  private long inMemRecordsSize = 0;
 
   // Checksum calculator for each partition. Empty when shuffle checksum disabled.
   private final Checksum[] partitionChecksums;
@@ -136,6 +142,8 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
       (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.numElementsForSpillThreshold =
       (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
+    this.recordsSizeForSpillThreshold =
+      (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
     this.writeMetrics = writeMetrics;
     this.inMemSorter = new ShuffleInMemorySorter(
       this, initialSize, (boolean) conf.get(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT()));
@@ -338,6 +346,7 @@ private long freeMemory() {
     allocatedPages.clear();
     currentPage = null;
     pageCursor = 0;
+    inMemRecordsSize = 0;
     return memoryFreed;
   }
 
@@ -417,12 +426,17 @@ private void acquireNewPageIfNecessary(int required) {
   public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
     throws IOException {
 
-    // for tests
     assert(inMemSorter != null);
     if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
-      logger.info("Spilling data because number of spilledRecords crossed the threshold {}" +
+      logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold {}",
+        MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS$.MODULE$, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD$.MODULE$, numElementsForSpillThreshold));
       spill();
+    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
+      logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE$.MODULE$, inMemRecordsSize),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD$.MODULE$, recordsSizeForSpillThreshold));
+      spill();
     }
 
     growPointerArrayIfNecessary();
@@ -439,6 +453,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
+    inMemRecordsSize += required;
   }
 
   /**
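Both sorters in this patch implement the same two-part trigger: spill when the record count crosses the existing numElementsForSpillThreshold, or (new here) when the bytes accumulated in inMemRecordsSize cross recordsSizeForSpillThreshold. A minimal, self-contained Scala sketch of that decision, with parameter names mirroring the patch (an illustration, not code from the patch):

    // Spill when either the record count or the accumulated in-memory record
    // size (the `required` bytes summed in insertRecord) crosses its threshold.
    def shouldForceSpill(
        numRecords: Int,
        inMemRecordsSize: Long,
        numElementsForSpillThreshold: Int,
        recordsSizeForSpillThreshold: Long): Boolean = {
      numRecords >= numElementsForSpillThreshold ||
        inMemRecordsSize >= recordsSizeForSpillThreshold
    }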
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index af421e903ba3..b99ac3079c56 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -80,6 +80,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    */
   private final int numElementsForSpillThreshold;
 
+  /**
+   * Force this sorter to spill when the size in memory is beyond this threshold.
+   */
+  private final long recordsSizeForSpillThreshold;
+
   /**
    * Memory pages that hold the records being sorted. The pages in this list are freed when
    * spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -92,6 +97,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
   // These variables are reset after spilling:
   @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+  private long inMemRecordsSize = 0;
 
   private MemoryBlock currentPage = null;
   private long pageCursor = -1;
@@ -110,11 +116,13 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
+      long recordsSizeForSpillThreshold,
       UnsafeInMemorySorter inMemorySorter,
       long existingMemoryConsumption) throws IOException {
     UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
       serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
-      pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */);
+      pageSizeBytes, numElementsForSpillThreshold, recordsSizeForSpillThreshold,
+      inMemorySorter, false /* ignored */);
     sorter.spill(Long.MAX_VALUE, sorter);
     taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
     sorter.totalSpillBytes += existingMemoryConsumption;
@@ -133,10 +141,11 @@ public static UnsafeExternalSorter create(
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
+      long recordsSizeForSpillThreshold,
       boolean canUseRadixSort) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
       taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes,
-      numElementsForSpillThreshold, null, canUseRadixSort);
+      numElementsForSpillThreshold, recordsSizeForSpillThreshold, null, canUseRadixSort);
   }
 
   private UnsafeExternalSorter(
@@ -149,6 +158,7 @@ private UnsafeExternalSorter(
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
+      long recordsSizeForSpillThreshold,
       @Nullable UnsafeInMemorySorter existingInMemorySorter,
       boolean canUseRadixSort) {
     super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
@@ -178,6 +188,7 @@ private UnsafeExternalSorter(
       this.inMemSorter = existingInMemorySorter;
     }
     this.peakMemoryUsedBytes = getMemoryUsage();
+    this.recordsSizeForSpillThreshold = recordsSizeForSpillThreshold;
     this.numElementsForSpillThreshold = numElementsForSpillThreshold;
 
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
@@ -238,6 +249,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
       // pages will currently be counted as memory spilled even though that space isn't actually
       // written to disk. This also counts the space needed to store the sorter's pointer array.
       inMemSorter.freeMemory();
+      inMemRecordsSize = 0;
       // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
       // records. Otherwise, if the task is over allocated memory, then without freeing the memory
       // pages, we might not be able to get memory for the pointer array.
@@ -480,9 +492,15 @@ public void insertRecord(
 
     assert(inMemSorter != null);
     if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
-      logger.info("Spilling data because number of spilledRecords crossed the threshold {}",
+      logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold {}",
+        MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS$.MODULE$, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD$.MODULE$, numElementsForSpillThreshold));
       spill();
+    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
+      logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE$.MODULE$, inMemRecordsSize),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD$.MODULE$, recordsSizeForSpillThreshold));
+      spill();
     }
 
     final int uaoSize = UnsafeAlignedOffset.getUaoSize();
@@ -497,6 +515,7 @@ public void insertRecord(
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
+    inMemRecordsSize += required;
   }
 
   /**
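Every UnsafeExternalSorter creation site must now pass the size threshold immediately after the element-count threshold. A hedged sketch of a call site under the new signature; all argument names are placeholders standing in for whatever the caller has in scope:

    // Hypothetical call site; values would come from configuration as in the hunks above.
    val sorter = UnsafeExternalSorter.create(
      taskMemoryManager, blockManager, serializerManager, taskContext,
      recordComparatorSupplier, prefixComparator,
      1024,                           // initialSize
      pageSizeBytes,
      numElementsForSpillThreshold,   // existing count-based trigger
      recordsSizeForSpillThreshold,   // new size-based trigger, in bytes
      canUseRadixSort)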
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index ca06cb5ba764..9e2eb4e0b56a 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1596,6 +1596,18 @@ package object config {
     .intConf
     .createWithDefault(Integer.MAX_VALUE)
 
+  private[spark] val SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD =
+    ConfigBuilder("spark.shuffle.spill.maxRecordsSizeForSpillThreshold")
+      .internal()
+      .doc("The maximum size in memory before forcing the shuffle sorter to spill. " +
+        "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " +
+        "until we reach some limitations, like the max page size limitation for the pointer " +
+        "array in the sorter.")
+      .version("4.1.0")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(v => v > 0, "The threshold should be positive.")
+      .createWithDefault(Long.MaxValue)
+
   private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
     ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
       .internal()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 7f2a1a8419a7..3ee4c9c0b401 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -58,6 +58,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
   private[this] val numElementsForceSpillThreshold: Int =
     SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD)
 
+  // Force this collection to spill when its size is greater than this threshold
+  private[this] val maxSizeForceSpillThreshold: Long =
+    SparkEnv.get.conf.get(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)
+
   // Threshold for this collection's size in bytes before we start tracking its memory usage
   // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
   @volatile private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -80,21 +84,25 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
    * @return true if `collection` was spilled to disk; false otherwise
    */
   protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
-    var shouldSpill = false
-    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
+    val shouldSpill = if (_elementsRead > numElementsForceSpillThreshold
+      || currentMemory > maxSizeForceSpillThreshold) {
+      // Check number of elements or memory usage limits, whichever is hit first
+      true
+    } else if (_elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
       val granted = acquireMemory(amountToRequest)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
-      shouldSpill = currentMemory >= myMemoryThreshold
+      currentMemory >= myMemoryThreshold
+    } else {
+      false
     }
-    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
     // Actually spill
     if (shouldSpill) {
       _spillCount += 1
-      logSpillage(currentMemory)
+      logSpillage(currentMemory, _elementsRead)
       spill(collection)
       _elementsRead = 0
       _memoryBytesSpilled += currentMemory
@@ -140,12 +148,14 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
    * Prints a standard log message detailing spillage.
    *
    * @param size number of bytes spilled
+   * @param elements number of elements read from input since last spill
    */
-  @inline private def logSpillage(size: Long): Unit = {
+  @inline private def logSpillage(size: Long, elements: Int): Unit = {
     val threadId = Thread.currentThread().getId
     logInfo(log"Thread ${MDC(LogKeys.THREAD_ID, threadId)} " +
       log"spilling in-memory map of ${MDC(LogKeys.BYTE_SIZE,
-      org.apache.spark.util.Utils.bytesToString(size))} to disk " +
+      org.apache.spark.util.Utils.bytesToString(size))} " +
+      log"(elements: ${MDC(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, elements)}) to disk " +
       log"(${MDC(LogKeys.NUM_SPILLS, _spillCount)} times so far)")
   }
 }
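The rewrite above also reorders the checks in maybeSpill: the force-spill conditions now run on every call, while the grow-or-spill path is still sampled every 32 elements. A self-contained sketch of the resulting decision, assuming the parameters stand in for Spillable's private fields and `acquire` stands in for acquireMemory (a sketch, not the class itself):

    def shouldSpill(
        elementsRead: Int,
        currentMemory: Long,
        myMemoryThreshold: Long,
        numElementsForceSpillThreshold: Int,
        maxSizeForceSpillThreshold: Long)(acquire: Long => Long): Boolean = {
      if (elementsRead > numElementsForceSpillThreshold ||
          currentMemory > maxSizeForceSpillThreshold) {
        // Force-spill limits are checked first, on every insert.
        true
      } else if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
        // Try to roughly double the memory grant; spill only if the grant
        // could not keep up with the collection's current size.
        val granted = acquire(2 * currentMemory - myMemoryThreshold)
        currentMemory >= myMemoryThreshold + granted
      } else {
        false
      }
    }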
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 9e83717f5208..8ed929461e78 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -87,9 +87,13 @@ public int compare(
   private final long pageSizeBytes = conf.getSizeAsBytes(
     package$.MODULE$.BUFFER_PAGESIZE().key(), "4m");
 
-  private final int spillThreshold =
+  private final int spillElementsThreshold =
     (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
 
+  private final long spillSizeThreshold =
+    (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
+
+
   @BeforeEach
   public void setUp() throws Exception {
     MockitoAnnotations.openMocks(this).close();
@@ -163,7 +167,8 @@ private UnsafeExternalSorter newSorter() throws IOException {
       prefixComparator,
       /* initialSize */ 1024,
       pageSizeBytes,
-      spillThreshold,
+      spillElementsThreshold,
+      spillSizeThreshold,
       shouldUseRadixSort());
   }
 
@@ -453,7 +458,8 @@ public void forcedSpillingWithoutComparator() throws Exception {
       null,
       /* initialSize */ 1024,
      pageSizeBytes,
-      spillThreshold,
+      spillElementsThreshold,
+      spillSizeThreshold,
       shouldUseRadixSort());
     long[] record = new long[100];
     int recordSize = record.length * 8;
@@ -515,7 +521,8 @@ public void testPeakMemoryUsed() throws Exception {
       prefixComparator,
       1024,
       pageSizeBytes,
-      spillThreshold,
+      spillElementsThreshold,
+      spillSizeThreshold,
       shouldUseRadixSort());
 
     // Peak memory should be monotonically increasing. More specifically, every time
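A hypothetical extra check (not part of the patch) that the renamed fields make easy to write: disable the count trigger and pin the size trigger low, so any spill must come from the new path. All identifiers other than the literals are placeholders borrowed from the suite:

    val sorter = UnsafeExternalSorter.create(
      taskMemoryManager, blockManager, serializerManager, taskContext,
      recordComparatorSupplier, prefixComparator,
      1024, pageSizeBytes,
      Int.MaxValue,  // count-based trigger effectively disabled
      64L,           // force a spill once more than 64 bytes are buffered
      shouldUseRadixSort())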
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c921f9d9c08b..a3179ea16e44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3356,6 +3356,13 @@ object SQLConf {
       .intConf
       .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
+    buildConf("spark.sql.windowExec.buffer.spill.size.threshold")
+      .internal()
+      .doc("Threshold for the total size of rows to be spilled by the window operator.")
+      .version("4.1.0")
+      .fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)
+
   val WINDOW_GROUP_LIMIT_THRESHOLD =
     buildConf("spark.sql.optimizer.windowGroupLimitThreshold")
       .internal()
@@ -3377,6 +3384,15 @@ object SQLConf {
       .intConf
       .createWithDefault(4096)
 
+  val SESSION_WINDOW_BUFFER_SPILL_SIZE_THRESHOLD =
+    buildConf("spark.sql.sessionWindow.buffer.spill.size.threshold")
+      .internal()
+      .doc("Threshold for the total size of rows to be spilled by the session window " +
+        "operator. Note that the buffer is used only for queries where Spark cannot apply " +
+        "aggregations while determining the session window.")
+      .version("4.1.0")
+      .fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)
+
   val SESSION_WINDOW_BUFFER_SPILL_THRESHOLD =
     buildConf("spark.sql.sessionWindow.buffer.spill.threshold")
       .internal()
@@ -3420,6 +3436,13 @@ object SQLConf {
       .intConf
       .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
+    buildConf("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold")
+      .internal()
+      .doc("Threshold for the total size of rows to be spilled by the sort merge join operator.")
+      .version("4.1.0")
+      .fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)
+
   val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
     buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold")
       .internal()
@@ -3437,6 +3460,13 @@ object SQLConf {
       .intConf
       .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
+    buildConf("spark.sql.cartesianProductExec.buffer.spill.size.threshold")
+      .internal()
+      .doc("Threshold for the total size of rows to be spilled by the cartesian product operator.")
+      .version("4.1.0")
+      .fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)
+
   val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames")
     .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" +
       " as regular expressions.")
@@ -6679,24 +6709,35 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def windowExecBufferSpillSizeThreshold: Long = getConf(WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)
+
   def windowGroupLimitThreshold: Int = getConf(WINDOW_GROUP_LIMIT_THRESHOLD)
 
   def sessionWindowBufferInMemoryThreshold: Int = getConf(SESSION_WINDOW_BUFFER_IN_MEMORY_THRESHOLD)
 
   def sessionWindowBufferSpillThreshold: Int = getConf(SESSION_WINDOW_BUFFER_SPILL_THRESHOLD)
 
+  def sessionWindowBufferSpillSizeThreshold: Long =
+    getConf(SESSION_WINDOW_BUFFER_SPILL_SIZE_THRESHOLD)
+
   def sortMergeJoinExecBufferInMemoryThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
 
   def sortMergeJoinExecBufferSpillThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def sortMergeJoinExecBufferSpillSizeThreshold: Long =
+    getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)
+
   def cartesianProductExecBufferInMemoryThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
 
   def cartesianProductExecBufferSpillThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def cartesianProductExecBufferSizeSpillThreshold: Long =
+    getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)
+
   def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
 
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 90b55a8586de..66637ac2bbd9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -120,6 +120,8 @@ private UnsafeExternalRowSorter(
       pageSizeBytes,
       (int) SparkEnv.get().conf().get(
         package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
+      (long) SparkEnv.get().conf().get(
+        package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD()),
       canUseRadixSort
     );
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8587d9290078..af8d5a4610f6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -242,6 +242,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti
       map.getPageSizeBytes(),
       (int) SparkEnv.get().conf().get(
         package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
+      (long) SparkEnv.get().conf().get(
+        package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD()),
       map);
   }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 6f2d12e6b790..6affcb61b8d6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -60,9 +60,10 @@ public UnsafeKVExternalSorter(
       BlockManager blockManager,
       SerializerManager serializerManager,
       long pageSizeBytes,
-      int numElementsForSpillThreshold) throws IOException {
+      int numElementsForSpillThreshold,
+      long maxRecordsSizeForSpillThreshold) throws IOException {
     this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
-      numElementsForSpillThreshold, null);
+      numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null);
   }
 
   public UnsafeKVExternalSorter(
@@ -72,6 +73,7 @@ public UnsafeKVExternalSorter(
       SerializerManager serializerManager,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
+      long maxRecordsSizeForSpillThreshold,
       @Nullable BytesToBytesMap map) throws IOException {
     this.keySchema = keySchema;
     this.valueSchema = valueSchema;
@@ -98,6 +100,7 @@ public UnsafeKVExternalSorter(
         (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
         pageSizeBytes,
         numElementsForSpillThreshold,
+        maxRecordsSizeForSpillThreshold,
         canUseRadixSort);
     } else {
       // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow
@@ -165,6 +168,7 @@ public UnsafeKVExternalSorter(
         (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
         pageSizeBytes,
         numElementsForSpillThreshold,
+        maxRecordsSizeForSpillThreshold,
         inMemSorter,
         map.getTotalMemoryConsumption());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
index 59810adc4b22..ebf974a8a480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -52,9 +52,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
     initialSize: Int,
     pageSizeBytes: Long,
     numRowsInMemoryBufferThreshold: Int,
-    numRowsSpillThreshold: Int) extends Logging {
+    numRowsSpillThreshold: Int,
+    maxSizeSpillThreshold: Long) extends Logging {
 
-  def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) = {
+  def this(numRowsInMemoryBufferThreshold: Int,
+      numRowsSpillThreshold: Int,
+      maxSizeSpillThreshold: Long) = {
     this(
       TaskContext.get().taskMemoryManager(),
       SparkEnv.get.blockManager,
@@ -63,7 +66,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes,
       numRowsInMemoryBufferThreshold,
-      numRowsSpillThreshold)
+      numRowsSpillThreshold,
+      maxSizeSpillThreshold)
   }
 
   private val initialSizeOfInMemoryBuffer =
@@ -138,6 +142,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
         initialSize,
         pageSizeBytes,
         numRowsSpillThreshold,
+        maxSizeSpillThreshold,
         false)
 
       // populate with existing in-memory buffered rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index a4a6dc8e4ab0..b89b268dd3c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -332,6 +332,7 @@ class SortBasedAggregator(
       SparkEnv.get.serializerManager,
       TaskContext.get().taskMemoryManager().pageSizeBytes,
       SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD),
+      SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD),
       null
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
index 9b68e6f02a85..5384f939c31a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
@@ -79,6 +79,7 @@ class ObjectAggregationMap() {
       SparkEnv.get.serializerManager,
       TaskContext.get().taskMemoryManager().pageSizeBytes,
       SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD),
+      SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD),
       null
     )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala
index b5dfd4639d8f..c3786a5338d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsExec.scala
@@ -52,10 +52,11 @@ case class UpdatingSessionsExec(
   override protected def doExecute(): RDD[InternalRow] = {
     val inMemoryThreshold = conf.sessionWindowBufferInMemoryThreshold
     val spillThreshold = conf.sessionWindowBufferSpillThreshold
+    val spillSizeThreshold = conf.sessionWindowBufferSpillSizeThreshold
 
     child.execute().mapPartitions { iter =>
       new UpdatingSessionsIterator(iter, groupingExpression, sessionExpression,
-        child.output, inMemoryThreshold, spillThreshold)
+        child.output, inMemoryThreshold, spillThreshold, spillSizeThreshold)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
index 39b835f1f45f..64bb3717f52b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
@@ -43,7 +43,8 @@ class UpdatingSessionsIterator(
     sessionExpression: NamedExpression,
     inputSchema: Seq[Attribute],
     inMemoryThreshold: Int,
-    spillThreshold: Int) extends Iterator[InternalRow] {
+    spillThreshold: Int,
+    spillSizeThreshold: Long) extends Iterator[InternalRow] {
 
   private val groupingWithoutSession: Seq[NamedExpression] =
     groupingExpressions.diff(Seq(sessionExpression))
@@ -150,7 +151,8 @@ class UpdatingSessionsIterator(
     currentKeys = groupingKey.copy()
     currentSession = sessionStruct.copy()
 
-    rowsForCurrentSession = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+    rowsForCurrentSession = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold,
+      spillSizeThreshold)
     rowsForCurrentSession.add(currentRow.asInstanceOf[UnsafeRow])
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8a996bce251c..8065decb0dff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -36,11 +36,13 @@ class UnsafeCartesianRDD(
     left : RDD[UnsafeRow],
     right : RDD[UnsafeRow],
    inMemoryBufferThreshold: Int,
-    spillThreshold: Int)
+    spillThreshold: Int,
+    spillSizeThreshold: Long)
   extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
 
   override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
-    val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)
+    val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold,
+      spillSizeThreshold)
 
     val partition = split.asInstanceOf[CartesianPartition]
     rdd2.iterator(partition.s2, context).foreach(rowArray.add)
@@ -81,7 +83,8 @@ case class CartesianProductExec(
       leftResults,
       rightResults,
       conf.cartesianProductExecBufferInMemoryThreshold,
-      conf.cartesianProductExecBufferSpillThreshold)
+      conf.cartesianProductExecBufferSpillThreshold,
+      conf.cartesianProductExecBufferSizeSpillThreshold)
 
     pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
index 57ca135407d4..b4e52ba050b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
@@ -34,6 +34,7 @@ class SortMergeJoinEvaluatorFactory(
     output: Seq[Attribute],
     inMemoryThreshold: Int,
     spillThreshold: Int,
+    spillSizeThreshold: Long,
     numOutputRows: SQLMetric,
     spillSize: SQLMetric,
     onlyBufferFirstMatchedRow: Boolean)
@@ -85,6 +86,7 @@ class SortMergeJoinEvaluatorFactory(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources)
         private[this] val joinRow = new JoinedRow
@@ -130,6 +132,7 @@ class SortMergeJoinEvaluatorFactory(
           bufferedIter = RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources)
         val rightNullRow = new GenericInternalRow(right.output.length)
@@ -149,6 +152,7 @@ class SortMergeJoinEvaluatorFactory(
           bufferedIter = RowIterator.fromScala(leftIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources)
         val leftNullRow = new GenericInternalRow(left.output.length)
@@ -185,6 +189,7 @@ class SortMergeJoinEvaluatorFactory(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources,
           onlyBufferFirstMatchedRow)
@@ -222,6 +227,7 @@ class SortMergeJoinEvaluatorFactory(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources,
           onlyBufferFirstMatchedRow)
@@ -266,6 +272,7 @@ class SortMergeJoinEvaluatorFactory(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
+          spillSizeThreshold,
           spillSize,
           cleanupResources,
           onlyBufferFirstMatchedRow)
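The operators above all funnel into the three-argument ExternalAppendOnlyUnsafeRowArray constructor. A sketch of what each now supplies, with illustrative values; Long.MaxValue reproduces the default, i.e. a disabled size trigger:

    val buffer = new ExternalAppendOnlyUnsafeRowArray(
      4096,           // rows held in a plain in-memory buffer before using the sorter
      Int.MaxValue,   // row-count spill trigger
      Long.MaxValue)  // byte-size spill trigger introduced by this patch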
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 8d49b1558d68..39387ebbb7ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -103,6 +103,10 @@ case class SortMergeJoinExec(
     conf.sortMergeJoinExecBufferSpillThreshold
   }
 
+  private def getSpillSizeThreshold: Long = {
+    conf.sortMergeJoinExecBufferSpillSizeThreshold
+  }
+
   // Flag to only buffer first matched row, to avoid buffering unnecessary rows.
   private val onlyBufferFirstMatchedRow = (joinType, condition) match {
     case (LeftExistence(_), None) => true
@@ -121,6 +125,7 @@ case class SortMergeJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     val spillSize = longMetric("spillSize")
     val spillThreshold = getSpillThreshold
+    val spillSizeThreshold = getSpillSizeThreshold
     val inMemoryThreshold = getInMemoryThreshold
     val evaluatorFactory = new SortMergeJoinEvaluatorFactory(
       leftKeys,
@@ -132,6 +137,7 @@ case class SortMergeJoinExec(
       output,
       inMemoryThreshold,
       spillThreshold,
+      spillSizeThreshold,
       numOutputRows,
       spillSize,
       onlyBufferFirstMatchedRow
@@ -222,11 +228,13 @@ case class SortMergeJoinExec(
     val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
 
     val spillThreshold = getSpillThreshold
+    val spillSizeThreshold = getSpillSizeThreshold
     val inMemoryThreshold = getInMemoryThreshold
 
     // Inline mutable state since not many join operations in a task
     val matches = ctx.addMutableState(clsName, "matches",
-      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
+      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold, ${spillSizeThreshold}L);",
+      forceInline = true)
     // Copy the streamed keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
 
@@ -1044,6 +1052,7 @@ case class SortMergeJoinExec(
  * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
  *                          internal buffer
  * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
+ * @param spillSizeThreshold Threshold for size of rows to be spilled by internal buffer
  * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found
  * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the first matching row
  */
@@ -1055,6 +1064,7 @@ private[joins] class SortMergeJoinScanner(
     bufferedIter: RowIterator,
     inMemoryThreshold: Int,
     spillThreshold: Int,
+    spillSizeThreshold: Long,
     spillSize: SQLMetric,
     eagerCleanupResources: () => Unit,
     onlyBufferFirstMatch: Boolean = false) {
@@ -1069,7 +1079,7 @@ private[joins] class SortMergeJoinScanner(
   private[this] var matchJoinKey: InternalRow = _
   /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
   private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
-    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
   // At the end of the task, update the task's spill size for buffered side.
   TaskContext.get().addTaskCompletionListener[Unit](_ => {
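Note the L suffix in the codegen hunk above (${spillSizeThreshold}L): the threshold is a Long that gets interpolated into generated Java source, and its default of Long.MaxValue does not fit in a Java int literal. A sketch with a hypothetical Buffer class showing what the interpolation produces:

    val spillSizeThreshold: Long = Long.MaxValue
    val ctor = s"new Buffer(4096, 2147483647, ${spillSizeThreshold}L);"
    // => new Buffer(4096, 2147483647, 9223372036854775807L);
    // Without the trailing L, the generated literal would overflow a Java int
    // and the generated class would fail to compile.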
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 26871b68dde8..a11d5af3fc3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -206,9 +206,10 @@ case class AggregateInPandasExec(
         case Some(sessionExpression) =>
           val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
           val spillThreshold = conf.windowExecBufferSpillThreshold
+          val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
 
           new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression,
-            child.output, inMemoryThreshold, spillThreshold)
+            child.output, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
         case None => iter
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index e7fc9c7391af..68f67060d308 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -148,6 +148,7 @@ class WindowInPandasEvaluatorFactory(
 
   private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
   private val spillThreshold = conf.windowExecBufferSpillThreshold
+  private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
 
@@ -286,7 +287,8 @@ class WindowInPandasEvaluatorFactory(
 
         // Manage the current partition.
         val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold,
+            spillSizeThreshold)
         var bufferIterator: Iterator[UnsafeRow] = _
 
         val indexRow =
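One asymmetry worth noting in the hunks above: AggregateInPandasExec sizes its UpdatingSessionsIterator from the windowExec keys, while UpdatingSessionsExec (earlier in this patch) uses the sessionWindow keys. Tuning size-based spilling for session windows may therefore touch both. Illustrative values only:

    spark.conf.set("spark.sql.windowExec.buffer.spill.size.threshold", "256m")
    spark.conf.set("spark.sql.sessionWindow.buffer.spill.size.threshold", "256m")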
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
index 9ff056a27946..d59a0e9f4639 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
@@ -45,6 +45,7 @@ class WindowEvaluatorFactory(
   private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
   private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
   private val spillThreshold = conf.windowExecBufferSpillThreshold
+  private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
 
   override def eval(
       partitionIndex: Int,
@@ -82,7 +83,8 @@ class WindowEvaluatorFactory(
 
         // Manage the current partition.
         val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold,
+            spillSizeThreshold)
 
         var bufferIterator: Iterator[UnsafeRow] = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 01e72daead44..6e9f33855715 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -1619,4 +1619,21 @@ class DataFrameWindowFunctionsSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-49386: Window spill with more than the inMemoryThreshold and spillSizeThreshold") {
+    val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
+    val window = Window.partitionBy($"key").orderBy($"value")
+
+    withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1",
+      SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> Int.MaxValue.toString) {
+      assertNotSpilled(sparkContext, "select") {
+        df.select($"key", sum("value").over(window)).collect()
+      }
+      withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD.key -> "1") {
+        assertSpilled(sparkContext, "select") {
+          df.select($"key", sum("value").over(window)).collect()
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 41f2e5c9a406..59508d7fc101 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -809,7 +809,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
       SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
       SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
+      testSpill()
+    }
+  }
+
+  test("SPARK-49386: test SortMergeJoin (with spill by size threshold)") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+      SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
+      SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> Int.MaxValue.toString,
+      SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD.key -> "1") {
+      testSpill()
+    }
+  }
+
+  private def testSpill(): Unit = {
     assertSpilled(sparkContext, "inner join") {
       checkAnswer(
         sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
@@ -896,7 +909,6 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       )
     }
   }
-  }
 
   test("outer broadcast hash join should not throw NPE") {
     withTempView("v1", "v2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
index 31b002a1e245..461c899325f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -107,7 +107,8 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase {
       for (_ <- 0L until iterations) {
         val array = new ExternalAppendOnlyUnsafeRowArray(
           ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
-          numSpillThreshold)
+          numSpillThreshold,
+          Long.MaxValue)
 
         rows.foreach(x => array.add(x))
 
@@ -146,6 +147,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase {
           1024,
           SparkEnv.get.memoryManager.pageSizeBytes,
           numSpillThreshold,
+          Long.MaxValue,
           false)
 
         rows.foreach(x =>
@@ -170,7 +172,9 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase {
     benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
       var sum = 0L
       for (_ <- 0L until iterations) {
-        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold)
+        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold,
+          numSpillThreshold,
+          Long.MaxValue)
 
         rows.foreach(x => array.add(x))
         val iterator = array.generateIterator()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
index b9e7367d54df..62ea7f2f9259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -47,7 +47,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes,
       inMemoryThreshold,
-      spillThreshold)
+      spillThreshold,
+      Long.MaxValue)
     try f(array) finally {
       array.clear()
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index b3370b6733d9..77ecea0d6293 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -128,7 +128,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
     val sorter = new UnsafeKVExternalSorter(
       keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
-      pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
+      pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get,
+      SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get
+    )
 
     // Insert the keys and values into the sorter
     inputData.foreach { case (k, v) =>
@@ -226,6 +228,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
         sparkContext.env.serializerManager,
         taskMemoryManager.pageSizeBytes(),
         Int.MaxValue,
+        Long.MaxValue,
         map)
     } finally {
       TaskContext.unset()
@@ -250,6 +253,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
       sparkContext.env.serializerManager,
       taskMemoryManager.pageSizeBytes(),
       Int.MaxValue,
+      Long.MaxValue,
       map)
     assert(sorter.getSpillSize === expectedSpillSize)
   } finally {
@@ -275,6 +279,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
       sparkContext.env.serializerManager,
       taskMemoryManager.pageSizeBytes(),
       Int.MaxValue,
+      Long.MaxValue,
       map1)
     val sorter2 = new UnsafeKVExternalSorter(
       schema,
@@ -283,6 +288,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
       sparkContext.env.serializerManager,
       taskMemoryManager.pageSizeBytes(),
       Int.MaxValue,
+      Long.MaxValue,
       map2)
     sorter1.merge(sorter2)
     assert(sorter1.getSpillSize === expectedSpillSize)
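The suites above isolate the new trigger by setting the row-count threshold to Int.MaxValue and the size threshold to one byte. The same pattern would extend to the cartesian product conf; a hypothetical test sketch (not part of the patch), assuming JoinSuite's helpers and test tables:

    withSQLConf(
      SQLConf.CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD.key -> Int.MaxValue.toString,
      SQLConf.CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD.key -> "1") {
      assertSpilled(sparkContext, "cross join") {
        sql("SELECT * FROM testData CROSS JOIN testData2").collect()
      }
    }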
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
index 187eda5d36f6..9aad453e8f56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala
@@ -65,10 +65,11 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
   // just copying default values to avoid bothering with SQLContext
   val inMemoryThreshold = 4096
   val spillThreshold = Int.MaxValue
+  val spillSizeThreshold = Long.MaxValue
 
   test("no row") {
     val iterator = new UpdatingSessionsIterator(None.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     assert(!iterator.hasNext)
   }
@@ -77,7 +78,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rows = List(createRow("a", 1, 100, 110, 10, 1.1))
 
     val iterator = new UpdatingSessionsIterator(rows.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     assert(iterator.hasNext)
 
@@ -95,7 +96,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rows = List(row1, row2, row3, row4)
 
     val iterator = new UpdatingSessionsIterator(rows.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     val retRows = rows.indices.map { _ =>
       assert(iterator.hasNext)
@@ -126,7 +127,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rowsAll = rows1 ++ rows2
 
     val iterator = new UpdatingSessionsIterator(rowsAll.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     val retRows1 = rows1.indices.map { _ =>
       assert(iterator.hasNext)
@@ -162,7 +163,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rowsAll = rows1 ++ rows2
 
     val iterator = new UpdatingSessionsIterator(rowsAll.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     val retRows1 = rows1.indices.map { _ =>
       assert(iterator.hasNext)
@@ -207,7 +208,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4
 
     val iterator = new UpdatingSessionsIterator(rowsAll.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     val retRows1 = rows1.indices.map { _ =>
       assert(iterator.hasNext)
@@ -260,7 +261,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rows = List(row1, row2, row3, row4)
 
     val iterator = new UpdatingSessionsIterator(rows.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     // UpdatingSessionIterator can't detect error on hasNext
     assert(iterator.hasNext)
@@ -296,7 +297,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rows = List(row1, row2, row3)
 
     val iterator = new UpdatingSessionsIterator(rows.iterator, keysWithSessionAttributes,
-      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold)
+      sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold, spillSizeThreshold)
 
     // UpdatingSessionIterator can't detect error on hasNext
     assert(iterator.hasNext)
@@ -339,7 +340,8 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val rows = List(row1, row2, row3, row4)
 
     val iterator = new UpdatingSessionsIterator(rows.iterator, Seq(noKeySessionAttribute),
-      noKeySessionAttribute, noKeyRowAttributes, inMemoryThreshold, spillThreshold)
+      noKeySessionAttribute, noKeyRowAttributes, inMemoryThreshold, spillThreshold,
+      spillSizeThreshold)
 
     val retRows = rows.indices.map { _ =>
       assert(iterator.hasNext)