
Commit cf38fe0

viirya authored and marmbrus committed

[SPARK-6844][SQL] Clean up accumulators used in InMemoryRelation when it is uncached

JIRA: https://issues.apache.org/jira/browse/SPARK-6844

Author: Liang-Chi Hsieh <[email protected]>

Closes #5475 from viirya/cache_memory_leak and squashes the following commits:

0b41235 [Liang-Chi Hsieh] fix style.
dc1d5d5 [Liang-Chi Hsieh] For comments.
78af229 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cache_memory_leak
26c9bb6 [Liang-Chi Hsieh] Add configuration to enable in-memory table scan accumulators.
1c3b06e [Liang-Chi Hsieh] Clean up accumulators used in InMemoryRelation when it is uncached.

1 parent 8584276 · commit cf38fe0

File tree: 4 files changed, +55 -14 lines

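Why the leak happens, in miniature: every InMemoryRelation creates a batch-statistics accumulator on the driver, and the driver-side Accumulators registry keeps a reference to it for the lifetime of the SparkContext. Before this commit, uncaching only unpersisted the cached RDD and never touched the registry, so the accumulator and the statistics rows it collected could never be reclaimed. Below is a minimal, Spark-free sketch of that failure mode; FakeAccumulator and FakeRegistry are hypothetical stand-ins, not Spark internals.

import scala.collection.mutable

final class FakeAccumulator(val id: Long) {
  // Stands in for the per-batch statistics rows the real accumulator collects.
  val buffered = mutable.ArrayBuffer.empty[Array[Byte]]
}

object FakeRegistry {
  private var nextId = 0L
  val originals = mutable.Map.empty[Long, FakeAccumulator]

  def create(): FakeAccumulator = {
    nextId += 1
    val acc = new FakeAccumulator(nextId)
    originals(acc.id) = acc // reference held for the driver's lifetime
    acc
  }

  // The cleanup step this commit adds via Accumulators.remove.
  def remove(id: Long): Unit = originals -= id
}

object LeakDemo extends App {
  val acc = FakeRegistry.create()
  acc.buffered += new Array[Byte](1024 * 1024) // stats buffered while the table is cached
  // Pre-fix behavior: only the RDD was unpersisted and this entry stayed forever.
  FakeRegistry.remove(acc.id) // post-fix: uncache() also deregisters
  assert(FakeRegistry.originals.isEmpty)
}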
sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
-    cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+    cachedData(dataIndex).cachedRepresentation.uncache(blocking)
     cachedData.remove(dataIndex)
   }

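For context, here is the user-facing path that now reaches uncache: a hedged sketch against the Spark 1.x API this commit targets, where the table name "people" is purely illustrative.

import org.apache.spark.sql.SQLContext

def cacheRoundTrip(sqlContext: SQLContext): Unit = {
  sqlContext.cacheTable("people")    // plans an InMemoryRelation, registering its batchStats accumulator
  sqlContext.table("people").count() // materializes the cached column buffers
  sqlContext.uncacheTable("people")  // routes through the CacheManager change above: the relation's
                                     // uncache(blocking) unpersists buffers and drops the accumulator
}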
sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 34 additions & 13 deletions

@@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar

 import java.nio.ByteBuffer

-import org.apache.spark.Accumulator
+import org.apache.spark.{Accumulable, Accumulator, Accumulators}
 import org.apache.spark.sql.catalyst.expressions

 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation(
     child: SparkPlan,
     tableName: Option[String])(
     private var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    private var _statistics: Statistics = null)
+    private var _statistics: Statistics = null,
+    private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null)
   extends LogicalPlan with MultiInstanceRelation {

-  private val batchStats =
-    child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+  private val batchStats: Accumulable[ArrayBuffer[Row], Row] =
+    if (_batchStats == null) {
+      child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+    } else {
+      _batchStats
+    }

   val partitionStatistics = new PartitionStatistics(output)

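The reuse-or-create branch above exists because InMemoryRelation is recreated by withOutput and newInstance (the next two hunks); a plain val initialized in the class body would register a fresh accumulator for every copy. Threading the live accumulator through a curried constructor argument keeps one shared instance per cached relation. A self-contained sketch of the idiom, with a hypothetical Relation in place of the real plan node:

import scala.collection.mutable.ArrayBuffer

final case class Relation(output: Seq[String])(
    private var _stats: ArrayBuffer[Long] = null) {

  // Create on first construction, reuse on every subsequent copy.
  val stats: ArrayBuffer[Long] =
    if (_stats == null) ArrayBuffer.empty[Long] else _stats

  def withOutput(newOutput: Seq[String]): Relation =
    Relation(newOutput)(stats) // the copy shares the same stats instance
}

object ShareDemo extends App {
  val r1 = Relation(Seq("a"))()
  val r2 = r1.withOutput(Seq("b"))
  assert(r1.stats eq r2.stats) // one collection total, not one per copy
}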
@@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-      _cachedColumnBuffers, statisticsToBePropagated)
+      _cachedColumnBuffers, statisticsToBePropagated, batchStats)
   }

   override def children: Seq[LogicalPlan] = Seq.empty
@@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation(
       child,
       tableName)(
       _cachedColumnBuffers,
-      statisticsToBePropagated).asInstanceOf[this.type]
+      statisticsToBePropagated,
+      batchStats).asInstanceOf[this.type]
   }

   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, statisticsToBePropagated)
+    Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
+
+  private[sql] def uncache(blocking: Boolean): Unit = {
+    Accumulators.remove(batchStats.id)
+    cachedColumnBuffers.unpersist(blocking)
+    _cachedColumnBuffers = null
+  }
 }

 private[sql] case class InMemoryColumnarTableScan(
@@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan(
     }
   }

+  lazy val enableAccumulators: Boolean =
+    sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
+
   // Accumulators used for testing purposes
-  val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
-  val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)

   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning

   override def execute(): RDD[Row] = {
-    readPartitions.setValue(0)
-    readBatches.setValue(0)
+    if (enableAccumulators) {
+      readPartitions.setValue(0)
+      readBatches.setValue(0)
+    }

     relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
       val partitionFilter = newPredicate(
@@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan(
         }
       }

-      if (rows.hasNext) {
+      if (rows.hasNext && enableAccumulators) {
        readPartitions += 1
      }

@@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan(
           logInfo(s"Skipping partition based on stats $statsString")
           false
         } else {
-          readBatches += 1
+          if (enableAccumulators) {
+            readBatches += 1
+          }
           true
         }
       }

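A note on the scan-side changes: readPartitions and readBatches are test-only, yet as strict vals they were created and registered with the driver for every InMemoryColumnarTableScan. Making them lazy and guarding each access behind enableAccumulators means they are never initialized unless spark.sql.inMemoryTableScanStatistics.enable is set. A generic sketch of this config-gated lazy-metric idiom (the conf map and counter below are illustrative, not Spark API):

import java.util.concurrent.atomic.AtomicInteger

object GatedMetrics {
  private val conf = Map("metrics.enable" -> "false") // stand-in for the SQL conf

  lazy val enableMetrics: Boolean =
    conf.getOrElse("metrics.enable", "false").toBoolean

  // A lazy val is only initialized (hence only allocated and registered) on
  // first access, and every access below is gated by the flag.
  lazy val readCounter: AtomicInteger = new AtomicInteger(0)

  def recordRead(): Unit = if (enableMetrics) readCounter.incrementAndGet()
}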
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 18 additions & 0 deletions

@@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps}

 import org.scalatest.concurrent.Eventually._

+import org.apache.spark.Accumulators
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest {
     sql("Clear CACHE")
     assert(cacheManager.isEmpty)
   }
+
+  test("Clear accumulators when uncacheTable to prevent memory leaking") {
+    val accsSize = Accumulators.originals.size
+
+    sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
+    sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
+    cacheTable("t1")
+    cacheTable("t2")
+    sql("SELECT * FROM t1").count()
+    sql("SELECT * FROM t2").count()
+    sql("SELECT * FROM t1").count()
+    sql("SELECT * FROM t2").count()
+    uncacheTable("t1")
+    uncacheTable("t2")
+
+    assert(accsSize >= Accumulators.originals.size)
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be

     // Enable in-memory partition pruning
     setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+    // Enable in-memory table scan accumulators
+    setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
   }

   override protected def afterAll(): Unit = {

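Because the scan accumulators are now off by default, any suite that asserts on readPartitions or readBatches has to opt in, which is all this last hunk does. Where the flag should not bleed into other suites, the usual save-and-restore pattern applies; a hedged sketch using only the getConf/setConf calls already visible in this commit:

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql.test.TestSQLContext._

class InMemoryScanStatsFlagSuite extends FunSuite with BeforeAndAfterAll {
  private var originalValue: String = _

  override protected def beforeAll(): Unit = {
    originalValue = getConf("spark.sql.inMemoryTableScanStatistics.enable", "false")
    setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
  }

  override protected def afterAll(): Unit = {
    setConf("spark.sql.inMemoryTableScanStatistics.enable", originalValue)
  }
}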