From 4dc6ef4a87cc4ed8325d3d9ed76a4ec7760a1cc8 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Fri, 18 Oct 2019 20:59:44 +0800
Subject: [PATCH 1/8] [SPARK-21492][SQL] Fix memory leak in SortMergeJoin

---
 .../apache/spark/sql/internal/SQLConf.scala   |  8 +++
 .../execution/UnsafeExternalRowSorter.java    | 11 +++-
 .../apache/spark/sql/execution/SortExec.scala | 29 +++++++++--
 .../spark/sql/execution/SparkPlan.scala       |  9 ++++
 .../execution/joins/SortMergeJoinExec.scala   | 46 ++++++++++++----
 .../org/apache/spark/sql/JoinSuite.scala      | 52 ++++++++++++++++++-
 6 files changed, 137 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4944099fcc0d..b8f8c52cf9fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1437,6 +1437,14 @@ object SQLConf {
     .intConf
     .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES =
+    buildConf("spark.sql.sortMergeJoinExec.eagerCleanupResources")
+      .internal()
+      .doc("When true, the SortMergeJoinExec will trigger the cleanup of all upstream " +
+        "resources right after it finishes computing.")
+      .booleanConf
+      .createWithDefault(true)
+
   val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
     buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
       .internal()
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 863d80b5cb9c..3123f2187da8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -52,6 +52,12 @@ public final class UnsafeExternalRowSorter {
   private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
   private final UnsafeExternalSorter sorter;
 
+  // This flag makes sure that cleanupResources() has been called. After the cleanup work,
+  // iterator.next should always return false. Downstream operators trigger the resource
+  // cleanup when they find there is no need to keep the iterator any more.
+  // See more details in SPARK-21492.
+  private boolean isReleased = false;
+
   public abstract static class PrefixComputer {
 
     public static class Prefix {
@@ -157,7 +163,8 @@ public long getSortTimeNanos() {
     return sorter.getSortTimeNanos();
   }
 
-  private void cleanupResources() {
+  public void cleanupResources() {
+    isReleased = true;
     sorter.cleanupResources();
   }
 
@@ -176,7 +183,7 @@ public Iterator<UnsafeRow> sort() throws IOException {
 
         @Override
         public boolean hasNext() {
-          return sortedIterator.hasNext();
+          return !isReleased && sortedIterator.hasNext();
         }
 
         @Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0a955d6a7523..cd6c706a4d13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -62,6 +62,14 @@ case class SortExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
+  private[sql] var rowSorter: UnsafeExternalRowSorter = _
+
+  /**
+   * This method is invoked only once for each SortExec instance to initialize an
+   * UnsafeExternalRowSorter; both `plan.execute` and code generation use it.
+   * In the code generation code path, we need to call this function outside the class, so we
+   * should make it public.
+   */
   def createSorter(): UnsafeExternalRowSorter = {
     val ordering = newOrdering(sortOrder, output)
 
@@ -87,13 +95,13 @@ case class SortExec(
     }
 
     val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-    val sorter = UnsafeExternalRowSorter.create(
+    rowSorter = UnsafeExternalRowSorter.create(
       schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
 
     if (testSpillFrequency > 0) {
-      sorter.setTestSpillFrequency(testSpillFrequency)
+      rowSorter.setTestSpillFrequency(testSpillFrequency)
     }
-    sorter
+    rowSorter
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -127,7 +135,7 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
-  override protected def doProduce(ctx: CodegenContext): String = {
+  override protected[sql] def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort",
       v => s"$v = true;")
 
@@ -181,4 +189,17 @@ case class SortExec(
        |$sorterVariable.insertRow((UnsafeRow)${row.value});
      """.stripMargin
   }
+
+  /**
+   * In SortExec, we override cleanupResources to close the UnsafeExternalRowSorter.
+   */
+  override protected[sql] def cleanupResources(): Unit = {
+    super.cleanupResources()
+    if (rowSorter != null) {
+      // It is possible for rowSorter to be null here, for example when the iterator of the
+      // current task is empty: the downstream physical node (like SortMergeJoinExec) may
+      // trigger cleanupResources before rowSorter is initialized in createSorter.
+      rowSorter.cleanupResources()
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b4cdf9e16b7e..125f76282e3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -507,6 +507,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
     newOrdering(order, Seq.empty)
   }
+
+  /**
+   * Cleans up the resources used by the physical operator (if any).
+   * In general, all the resources should be cleaned up when the task finishes but operators
+   * like SortMergeJoinExec and LimitExec may want eager cleanup to free up tight
+   * resources (e.g., memory).
+   */
+  protected[sql] def cleanupResources(): Unit = {
+    children.foreach(_.cleanupResources())
+  }
 }
 
 trait LeafExecNode extends SparkPlan {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 189727a9bc88..30febd71327f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES
 import org.apache.spark.util.collection.BitSet
 
 /**
@@ -161,6 +162,10 @@ case class SortMergeJoinExec(
     sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
   }
 
+  private def needEagerCleanup: Boolean = {
+    sqlContext.conf.getConf(SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val spillThreshold = getSpillThreshold
@@ -191,7 +196,8 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(leftIter),
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -235,7 +241,8 @@ case class SortMergeJoinExec(
           streamedIter = RowIterator.fromScala(leftIter),
           bufferedIter = RowIterator.fromScala(rightIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         val rightNullRow = new GenericInternalRow(right.output.length)
         new LeftOuterIterator(
@@ -249,7 +256,8 @@ case class SortMergeJoinExec(
           streamedIter = RowIterator.fromScala(rightIter),
           bufferedIter = RowIterator.fromScala(leftIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         val leftNullRow = new GenericInternalRow(left.output.length)
         new RightOuterIterator(
@@ -283,7 +291,8 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(leftIter),
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -318,7 +327,8 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(leftIter),
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -360,7 +370,8 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(leftIter),
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
-          spillThreshold
+          spillThreshold,
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -640,6 +651,11 @@ case class SortMergeJoinExec(
       (evaluateVariables(leftVars), "")
     }
 
+    val eagerCleanup = if (needEagerCleanup) {
+      val thisPlan = ctx.addReferenceObj("plan", this)
+      s"$thisPlan.cleanupResources();"
+    } else ""
+
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
        |  ${leftVarDecl.mkString("\n")}
@@ -653,6 +669,7 @@ case class SortMergeJoinExec(
        |  }
        |  if (shouldStop()) return;
        |}
+       |$eagerCleanup
      """.stripMargin
   }
 }
@@ -678,6 +695,7 @@ case class SortMergeJoinExec(
  * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
  *                          internal buffer
  * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
+ * @param eagerCleanupResources the eager cleanup function to be invoked when no join rows are found
  */
 private[joins] class SortMergeJoinScanner(
     streamedKeyGenerator: Projection,
@@ -686,7 +704,8 @@ private[joins] class SortMergeJoinScanner(
     streamedIter: RowIterator,
     bufferedIter: RowIterator,
     inMemoryThreshold: Int,
-    spillThreshold: Int) {
+    spillThreshold: Int,
+    eagerCleanupResources: () => Unit) {
   private[this] var streamedRow: InternalRow = _
   private[this] var streamedRowKey: InternalRow = _
   private[this] var bufferedRow: InternalRow = _
@@ -710,7 +729,8 @@ private[joins] class SortMergeJoinScanner(
   def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
 
   /**
-   * Advances both input iterators, stopping when we have found rows with matching join keys.
+   * Advances both input iterators, stopping when we have found rows with matching join keys. If no
+   * join rows are found, eagerly clean up the resources.
    * @return true if matching rows have been found and false otherwise. If this returns true, then
    *         [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
    *         results.
@@ -720,7 +740,7 @@ private[joins] class SortMergeJoinScanner(
       // Advance the streamed side of the join until we find the next row whose join key contains
       // no nulls or we hit the end of the streamed iterator.
     }
-    if (streamedRow == null) {
+    val found = if (streamedRow == null) {
       // We have consumed the entire streamed iterator, so there can be no more matches.
       matchJoinKey = null
       bufferedMatches.clear()
       false
@@ -760,17 +780,19 @@ private[joins] class SortMergeJoinScanner(
         true
       }
     }
+    if (!found) eagerCleanupResources()
+    found
   }
 
   /**
    * Advances the streamed input iterator and buffers all rows from the buffered input that
-   * have matching keys.
+   * have matching keys. If no join rows are found, eagerly clean up the resources.
    * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
    *         then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
    *         join results.
    */
   final def findNextOuterJoinRows(): Boolean = {
-    if (!advancedStreamed()) {
+    val found = if (!advancedStreamed()) {
       // We have consumed the entire streamed iterator, so there can be no more matches.
      matchJoinKey = null
       bufferedMatches.clear()
       false
     } else {
@@ -800,6 +822,8 @@ private[joins] class SortMergeJoinScanner(
       // If there is a streamed input then we always return true
       true
     }
+    if (!found) eagerCleanupResources()
+    found
   }
 
   // --- Private methods --------------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 72742644ff34..7d2d93fa194f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,12 +22,16 @@ import java.util.Locale
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.scalatest.BeforeAndAfterAll
+
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec}
+import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
@@ -1040,3 +1044,49 @@ class JoinSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
   }
 }
+
+class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
+  import testImplicits._
+  import scala.collection.mutable.ArrayBuffer
+
+  private def checkCleanupResourceTriggered(plan: SparkPlan) : ArrayBuffer[SortExec] = {
+    // Check that cleanupResources is finally triggered in the SortExec node
+    val sorts = new ArrayBuffer[SortExec]()
+    plan.foreachUp {
+      case s: SortExec => sorts += s
+      case _ =>
+    }
+    sorts.foreach { sort =>
+      val sortExec = spy(sort)
+      verify(sortExec, atLeastOnce).cleanupResources()
+      verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
+    }
+    sorts
+  }
+
+  override def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
+    withSQLConf(
+      SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES.key -> "true") {
+      checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
+      super.checkAnswer(df, rows)
+    }
+  }
+
+  test("cleanupResource in code generation") {
+    withSQLConf(
+      SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(0, 10, 1, 2)
+      val df2 = spark.range(10).select($"id".as("b1"), (- $"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
+
+      val sorts = checkCleanupResourceTriggered(res.queryExecution.sparkPlan)
+      // Make sure SortExec did the code generation
+      sorts.foreach { sort =>
+        verify(spy(sort), atLeastOnce).doProduce(any())
+      }
+      checkAnswer(res, Row(0, 0, 0))
+    }
+  }
+}

From 631f3cba3509b00501724c505a18d695a6e2acfb Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Sat, 19 Oct 2019 11:31:52 +0800
Subject: [PATCH 2/8] Address comments

---
 .../apache/spark/sql/internal/SQLConf.scala   |  3 +++
 .../apache/spark/sql/execution/SortExec.scala |  3 ++-
 .../execution/joins/SortMergeJoinExec.scala   | 20 +++++++++++--------
 .../org/apache/spark/sql/JoinSuite.scala      |  2 +-
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8f8c52cf9fc..98633f139c07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2432,6 +2432,9 @@ class SQLConf extends Serializable with Logging {
   def sortMergeJoinExecBufferSpillThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def sortMergeJoinExecEagerCleanupResources: Boolean =
+    getConf(SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES)
+
   def cartesianProductExecBufferInMemoryThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cd6c706a4d13..c520495d5a1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -71,6 +71,7 @@ case class SortExec(
    * should make it public.
    */
   def createSorter(): UnsafeExternalRowSorter = {
+    assert(rowSorter == null)
     val ordering = newOrdering(sortOrder, output)
 
     // The comparator for comparing prefix
@@ -194,12 +195,12 @@ case class SortExec(
    * In SortExec, we override cleanupResources to close the UnsafeExternalRowSorter.
    */
   override protected[sql] def cleanupResources(): Unit = {
-    super.cleanupResources()
     if (rowSorter != null) {
       // It is possible for rowSorter to be null here, for example when the iterator of the
       // current task is empty: the downstream physical node (like SortMergeJoinExec) may
       // trigger cleanupResources before rowSorter is initialized in createSorter.
      rowSorter.cleanupResources()
     }
+    super.cleanupResources()
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 30febd71327f..ad9f25ad5645 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.internal.SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES
 import org.apache.spark.util.collection.BitSet
 
 /**
@@ -163,13 +162,18 @@ case class SortMergeJoinExec(
   }
 
   private def needEagerCleanup: Boolean = {
-    sqlContext.conf.getConf(SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES)
+    sqlContext.conf.sortMergeJoinExecEagerCleanupResources
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val spillThreshold = getSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
+    val cleanupResourceFunc: () => Unit = if (needEagerCleanup) {
+      cleanupResources
+    } else {
+      () => {}
+    }
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
@@ -197,7 +201,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         private[this] val joinRow = new JoinedRow
@@ -242,7 +246,7 @@ case class SortMergeJoinExec(
           bufferedIter = RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         val rightNullRow = new GenericInternalRow(right.output.length)
         new LeftOuterIterator(
@@ -257,7 +261,7 @@ case class SortMergeJoinExec(
           bufferedIter = RowIterator.fromScala(leftIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         val leftNullRow = new GenericInternalRow(left.output.length)
         new RightOuterIterator(
@@ -292,7 +296,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         private[this] val joinRow = new JoinedRow
@@ -328,7 +332,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         private[this] val joinRow = new JoinedRow
@@ -371,7 +375,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResources
+          cleanupResourceFunc
         )
         private[this] val joinRow = new JoinedRow
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 7d2d93fa194f..f07c0fadec4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1049,7 +1049,7 @@ class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
   import testImplicits._
   import scala.collection.mutable.ArrayBuffer
 
-  private def checkCleanupResourceTriggered(plan: SparkPlan) : ArrayBuffer[SortExec] = {
+  private def checkCleanupResourceTriggered(plan: SparkPlan): ArrayBuffer[SortExec] = {
    // Check that cleanupResources is finally triggered in the SortExec node
     val sorts = new ArrayBuffer[SortExec]()
     plan.foreachUp {

From ec0f160033a387f4691c5e4abc719a55ea37b074 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Sun, 20 Oct 2019 10:36:00 +0800
Subject: [PATCH 3/8] fix java doc

---
 sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f07c0fadec4c..0c1fcfc0a7b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1064,7 +1064,7 @@ class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
     sorts
   }
 
-  override def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
+  override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
     withSQLConf(
       SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES.key -> "true") {
       checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
       super.checkAnswer(df, rows)

From 93815f8ca4d40d31d0c7f95ce98b98c67a4aac1f Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Sun, 20 Oct 2019 14:52:42 +0800
Subject: [PATCH 4/8] ut fix

---
 .../src/main/scala/org/apache/spark/sql/execution/SortExec.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index c520495d5a1b..979da8a18497 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -71,7 +71,6 @@ case class SortExec(
    * should make it public.
    */
   def createSorter(): UnsafeExternalRowSorter = {
-    assert(rowSorter == null)
     val ordering = newOrdering(sortOrder, output)
 
     // The comparator for comparing prefix

From defaaf26d8510243ea6d3375a35e8e2284b678cc Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 21 Oct 2019 13:41:10 +0800
Subject: [PATCH 5/8] delete the config and add a new test case

---
 .../apache/spark/sql/internal/SQLConf.scala   | 11 --------
 .../execution/joins/SortMergeJoinExec.scala   | 27 ++++++-------------
 .../org/apache/spark/sql/JoinSuite.scala      | 22 ++++++++++-----
 3 files changed, 23 insertions(+), 37 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 98633f139c07..4944099fcc0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1437,14 +1437,6 @@ object SQLConf {
     .intConf
     .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
-  val SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES =
-    buildConf("spark.sql.sortMergeJoinExec.eagerCleanupResources")
-      .internal()
-      .doc("When true, the SortMergeJoinExec will trigger the cleanup of all upstream " +
-        "resources right after it finishes computing.")
-      .booleanConf
-      .createWithDefault(true)
-
 val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
   buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
     .internal()
@@ -2432,9 +2424,6 @@ class SQLConf extends Serializable with Logging {
   def sortMergeJoinExecBufferSpillThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
 
-  def sortMergeJoinExecEagerCleanupResources: Boolean =
-    getConf(SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES)
-
  def cartesianProductExecBufferInMemoryThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ad9f25ad5645..26fb0e5ffb1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -161,19 +161,10 @@ case class SortMergeJoinExec(
     sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
   }
 
-  private def needEagerCleanup: Boolean = {
-    sqlContext.conf.sortMergeJoinExecEagerCleanupResources
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val spillThreshold = getSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
-    val cleanupResourceFunc: () => Unit = if (needEagerCleanup) {
-      cleanupResources
-    } else {
-      () => {}
-    }
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
@@ -201,7 +192,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -246,7 +237,7 @@ case class SortMergeJoinExec(
           bufferedIter = RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         val rightNullRow = new GenericInternalRow(right.output.length)
         new LeftOuterIterator(
@@ -261,7 +252,7 @@ case class SortMergeJoinExec(
           bufferedIter = RowIterator.fromScala(leftIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         val leftNullRow = new GenericInternalRow(left.output.length)
         new RightOuterIterator(
@@ -296,7 +287,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -332,7 +323,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -375,7 +366,7 @@ case class SortMergeJoinExec(
           RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold,
-          cleanupResourceFunc
+          cleanupResources
         )
         private[this] val joinRow = new JoinedRow
@@ -655,10 +646,8 @@ case class SortMergeJoinExec(
       (evaluateVariables(leftVars), "")
     }
 
-    val eagerCleanup = if (needEagerCleanup) {
-      val thisPlan = ctx.addReferenceObj("plan", this)
-      s"$thisPlan.cleanupResources();"
-    } else ""
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val eagerCleanup = s"$thisPlan.cleanupResources();"
 
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 0c1fcfc0a7b8..777bf7dffae1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1065,16 +1065,12 @@ class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
   }
 
   override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
-    withSQLConf(
-      SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES.key -> "true") {
-      checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
-      super.checkAnswer(df, rows)
-    }
+    checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
+    super.checkAnswer(df, rows)
   }
 
-  test("cleanupResource in code generation") {
+  test("cleanupResource with code generation") {
     withSQLConf(
-      SQLConf.SORT_MERGE_JOIN_EXEC_EAGER_CLEANUP_RESOURCES.key -> "true",
       SQLConf.SHUFFLE_PARTITIONS.key -> "1",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df1 = spark.range(0, 10, 1, 2)
@@ -1089,4 +1085,16 @@ class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
       checkAnswer(res, Row(0, 0, 0))
     }
   }
+
+  test("cleanupResource without code generation") {
+    withSQLConf(
+      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(0, 10, 1, 2)
+      val df2 = spark.range(10).select($"id".as("b1"), (- $"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
+      checkAnswer(res, Row(0, 0, 0))
+    }
+  }
 }

From 7787d45e17ef99b41b6e4d13b5904e0e7e0cfb2e Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 21 Oct 2019 15:29:11 +0800
Subject: [PATCH 6/8] simplify test case

---
 .../org/apache/spark/sql/JoinSuite.scala      | 69 +++++++------------
 1 file changed, 23 insertions(+), 46 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 777bf7dffae1..64625d720a0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,11 +20,9 @@ package org.apache.spark.sql
 import java.util.Locale
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
-import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
-import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -41,6 +39,27 @@ import org.apache.spark.sql.types.StructType
 class JoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
+  private def checkCleanupResourceTriggered(plan: SparkPlan): ArrayBuffer[SortExec] = {
+    // SPARK-21492: Check that cleanupResources is finally triggered in the SortExec node for
+    // every test case
+    val sorts = new ArrayBuffer[SortExec]()
+    plan.foreachUp {
+      case s: SortExec => sorts += s
+      case _ =>
+    }
+    sorts.foreach { sort =>
+      val sortExec = spy(sort)
+      verify(sortExec, atLeastOnce).cleanupResources()
+      verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
+    }
+    sorts
+  }
+
+  override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
+    checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
+    super.checkAnswer(df, rows)
+  }
+
   setupTestData()
 
   def statisticSizeInByte(df: DataFrame): BigInt = {
@@ -1043,50 +1062,8 @@ class JoinSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
   }
-}
-
-class JoinWithResourceCleanSuite extends JoinSuite with BeforeAndAfterAll {
-  import testImplicits._
-  import scala.collection.mutable.ArrayBuffer
-
-  private def checkCleanupResourceTriggered(plan: SparkPlan): ArrayBuffer[SortExec] = {
-    // Check that cleanupResources is finally triggered in the SortExec node
-    val sorts = new ArrayBuffer[SortExec]()
-    plan.foreachUp {
-      case s: SortExec => sorts += s
-      case _ =>
-    }
-    sorts.foreach { sort =>
-      val sortExec = spy(sort)
-      verify(sortExec, atLeastOnce).cleanupResources()
-      verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
-    }
-    sorts
-  }
-
-  override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
-    checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
-    super.checkAnswer(df, rows)
-  }
-
-  test("cleanupResource with code generation") {
-    withSQLConf(
-      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.range(0, 10, 1, 2)
-      val df2 = spark.range(10).select($"id".as("b1"), (- $"id").as("b2"))
-      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
-
-      val sorts = checkCleanupResourceTriggered(res.queryExecution.sparkPlan)
-      // Make sure SortExec did the code generation
-      sorts.foreach { sort =>
-        verify(spy(sort), atLeastOnce).doProduce(any())
-      }
-      checkAnswer(res, Row(0, 0, 0))
-    }
-  }
-
-  test("cleanupResource without code generation") {
+  test("SPARK-21492: cleanupResource without code generation") {
     withSQLConf(
       SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
       SQLConf.SHUFFLE_PARTITIONS.key -> "1",

From b41f33aaaf225d77561683aa13a9a905863d1088 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 21 Oct 2019 16:06:57 +0800
Subject: [PATCH 7/8] comment address

---
 .../main/scala/org/apache/spark/sql/execution/SortExec.scala | 2 +-
 sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 979da8a18497..32d21d05e5f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -135,7 +135,7 @@ case class SortExec(
   // Name of sorter variable used in codegen.
  private var sorterVariable: String = _
 
-  override protected[sql] def doProduce(ctx: CodegenContext): String = {
+  override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort",
       v => s"$v = true;")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 64625d720a0e..1f1e15e7dfa2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -39,7 +39,7 @@ class JoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
-  private def checkCleanupResourceTriggered(plan: SparkPlan): ArrayBuffer[SortExec] = {
+  private def attachCleanupResourceChecker(plan: SparkPlan): Unit = {
     // SPARK-21492: Check that cleanupResources is finally triggered in the SortExec node for
     // every test case
     val sorts = new ArrayBuffer[SortExec]()
@@ -52,11 +52,10 @@ class JoinSuite extends QueryTest with SharedSparkSession {
       verify(sortExec, atLeastOnce).cleanupResources()
       verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
     }
-    sorts
   }
 
   override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
-    checkCleanupResourceTriggered(df.queryExecution.sparkPlan)
+    attachCleanupResourceChecker(df.queryExecution.sparkPlan)
     super.checkAnswer(df, rows)
   }
 
From 6d6dd5a942eb04e33a2b9d6e47a5ac041b8281b8 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 21 Oct 2019 16:10:34 +0800
Subject: [PATCH 8/8] simplify

---
 .../test/scala/org/apache/spark/sql/JoinSuite.scala | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 1f1e15e7dfa2..62f2d21e5270 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.util.Locale
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 
 import org.mockito.Mockito._
 
@@ -42,16 +42,13 @@ class JoinSuite extends QueryTest with SharedSparkSession {
   private def attachCleanupResourceChecker(plan: SparkPlan): Unit = {
     // SPARK-21492: Check that cleanupResources is finally triggered in the SortExec node for
     // every test case
-    val sorts = new ArrayBuffer[SortExec]()
     plan.foreachUp {
-      case s: SortExec => sorts += s
+      case s: SortExec =>
+        val sortExec = spy(s)
+        verify(sortExec, atLeastOnce).cleanupResources()
+        verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
       case _ =>
     }
-    sorts.foreach { sort =>
-      val sortExec = spy(sort)
-      verify(sortExec, atLeastOnce).cleanupResources()
-      verify(sortExec.rowSorter, atLeastOnce).cleanupResources()
-    }
   }
 
   override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = {
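
Reviewer note on the pattern implemented above (not part of the patch series): the fix threads the new SparkPlan.cleanupResources() hook from SortMergeJoinExec's scanner up to each SortExec, so sort buffers and spill files are freed as soon as the join knows no further rows can be produced, instead of waiting for task completion; the isReleased flag then keeps the sorted iterator's hasNext false afterwards. The following is a minimal, self-contained Scala sketch of that eager-cleanup iterator pattern under stated assumptions: the names EagerCleanupIterator, cleanup and EagerCleanupDemo are illustrative only, not Spark APIs.

    // Minimal sketch (hypothetical names, not Spark API): an iterator that fires a
    // cleanup callback as soon as it is exhausted, and whose hasNext stays false
    // after release, mirroring UnsafeExternalRowSorter.isReleased above.
    class EagerCleanupIterator[T](underlying: Iterator[T], cleanup: () => Unit)
      extends Iterator[T] {

      private var isReleased = false

      override def hasNext: Boolean = {
        val more = !isReleased && underlying.hasNext
        if (!more && !isReleased) {
          // No consumer will pull further rows: free resources now instead of
          // waiting for a task-completion callback, and guard against firing twice.
          isReleased = true
          cleanup()
        }
        more
      }

      override def next(): T = underlying.next()
    }

    object EagerCleanupDemo extends App {
      val rows = new EagerCleanupIterator(Iterator(1, 2, 3), () => println("resources released"))
      rows.foreach(println) // prints 1, 2, 3, then "resources released"
      println(rows.hasNext) // false; the cleanup callback is not invoked a second time
    }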