CacheManager.scala
@@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
-    cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+    cachedData(dataIndex).cachedRepresentation.uncache(blocking)
     cachedData.remove(dataIndex)
   }
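For context, a rough sketch of the leak this change fixes (a hypothetical driver-side loop; a SQLContext and a registered table "t" are assumed): before this patch, every cache/uncache cycle registered a fresh batchStats accumulator that was never deregistered.

    // Hypothetical repro, not part of the patch. With the old code the
    // driver-side accumulator registry grew by one entry per iteration,
    // because uncacheTable only unpersisted the cached column buffers.
    (1 to 1000).foreach { _ =>
      sqlContext.cacheTable("t")    // InMemoryRelation registers a batchStats accumulator
      sqlContext.uncacheTable("t")  // now also deregisters it via InMemoryRelation.uncache
    }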

InMemoryColumnarTableScan.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar

 import java.nio.ByteBuffer
 
-import org.apache.spark.Accumulator
+import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation(
     child: SparkPlan,
     tableName: Option[String])(
     private var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    private var _statistics: Statistics = null)
+    private var _statistics: Statistics = null,
+    private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
-  private val batchStats =
-    child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+  private val batchStats: Accumulable[ArrayBuffer[Row], Row] =
+    if (_batchStats == null) {
+      child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+    } else {
+      _batchStats
+    }
 
   val partitionStatistics = new PartitionStatistics(output)

@@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-      _cachedColumnBuffers, statisticsToBePropagated)
+      _cachedColumnBuffers, statisticsToBePropagated, batchStats)
   }
 
   override def children: Seq[LogicalPlan] = Seq.empty
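Design note, as a hedged sketch (call sites assumed): passing batchStats through the second parameter list keeps every copy produced by withOutput or newInstance pointing at the same registered accumulator, so a later uncache removes the one shared registry entry instead of stranding per-copy accumulators.

    // Illustrative only: the copy reuses the original's accumulator
    // rather than registering a fresh one.
    val copy = relation.withOutput(newOutput)  // shares relation's batchStats
    copy.uncache(blocking = false)             // deregisters the single shared accumulator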
@@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation(
       child,
       tableName)(
       _cachedColumnBuffers,
-      statisticsToBePropagated).asInstanceOf[this.type]
+      statisticsToBePropagated,
+      batchStats).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
 
   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, statisticsToBePropagated)
+    Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
+
+  private[sql] def uncache(blocking: Boolean): Unit = {
+    Accumulators.remove(batchStats.id)
+    cachedColumnBuffers.unpersist(blocking)
+    _cachedColumnBuffers = null
+  }
 }
 
 private[sql] case class InMemoryColumnarTableScan(
@@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan(
     }
   }
 
+  lazy val enableAccumulators: Boolean =
+    sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
+
   // Accumulators used for testing purposes
-  val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
-  val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
 
   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
 
   override def execute(): RDD[Row] = {
-    readPartitions.setValue(0)
-    readBatches.setValue(0)
+    if (enableAccumulators) {
+      readPartitions.setValue(0)
+      readBatches.setValue(0)
+    }
 
     relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
       val partitionFilter = newPredicate(
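The two scan accumulators are now lazy and only touched when spark.sql.inMemoryTableScanStatistics.enable is set, so ordinary queries never register them at all. A usage sketch for tests (the conf key is taken from this diff; the scan variable is assumed):

    // Enable the counters before running the query under test, then
    // read them back once the scan has executed.
    sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
    // ... run a query over a cached table ...
    // scan.readPartitions.value and scan.readBatches.value now hold the counts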
@@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan(
       }
     }
 
-    if (rows.hasNext) {
+    if (rows.hasNext && enableAccumulators) {
       readPartitions += 1
     }

@@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan(
logInfo(s"Skipping partition based on stats $statsString")
false
} else {
readBatches += 1
if (enableAccumulators) {
readBatches += 1
}
true
}
}
CachedTableSuite.scala
@@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps}

 import org.scalatest.concurrent.Eventually._
 
+import org.apache.spark.Accumulators
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest {
sql("Clear CACHE")
assert(cacheManager.isEmpty)
}

test("Clear accumulators when uncacheTable to prevent memory leaking") {
val accsSize = Accumulators.originals.size

sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to drop these two temp views after tests

cacheTable("t1")
cacheTable("t2")
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
uncacheTable("t1")
uncacheTable("t2")

assert(accsSize >= Accumulators.originals.size)
}
}
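A minimal sketch of the cleanup the reviewer asks for, assuming the dropTempTable helper exposed via TestSQLContext._ (a hypothetical follow-up, not part of this diff):

    // At the end of the test, unregister the temporary tables so later
    // tests in the suite start from a clean catalog.
    dropTempTable("t1")
    dropTempTable("t2")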
PartitionBatchPruningSuite.scala
@@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be

     // Enable in-memory partition pruning
     setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+    // Enable in-memory table scan accumulators
+    setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")

Review comment (@gatorsmile, Member, Jan 12, 2018): Why not set spark.sql.inMemoryTableScanStatistics.enable back to false in afterAll? (A sketch of that reset follows this hunk.)

   }
 
   override protected def afterAll(): Unit = {
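A sketch of the reset the reviewer suggests (the saved-value field originalInMemoryScanStats is hypothetical): capture the prior setting before overwriting it, then restore it when the suite finishes.

    // In beforeAll, remember the prior value (hypothetical field):
    //   originalInMemoryScanStats =
    //     getConf("spark.sql.inMemoryTableScanStatistics.enable", "false")
    override protected def afterAll(): Unit = {
      setConf("spark.sql.inMemoryTableScanStatistics.enable", originalInMemoryScanStats)
      // ... existing afterAll cleanup, e.g. restoring the pruning conf ...
    }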