2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2933,7 +2933,7 @@ class Dataset[T] private[sql](
*/
def storageLevel: StorageLevel = {
sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
-      cachedData.cachedRepresentation.storageLevel
+      cachedData.cachedRepresentation.cacheBuilder.storageLevel
}.getOrElse(StorageLevel.NONE)
}
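
For context, a minimal usage sketch of the API this hunk touches (assuming an active SparkSession named spark; MEMORY_AND_DISK is Dataset.cache()'s default level):

import org.apache.spark.storage.StorageLevel

val df = spark.range(100)
assert(df.storageLevel == StorageLevel.NONE)             // no cache entry yet
df.cache()                                               // registers the plan with the CacheManager
assert(df.storageLevel == StorageLevel.MEMORY_AND_DISK)  // now resolved via cacheBuilder.storageLevel
df.unpersist()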

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -71,7 +71,7 @@ class CacheManager extends Logging {

/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
-    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
cachedData.clear()
}
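
For orientation, the user-facing entry point that reaches this method, as a small sketch (assuming an active spark session; Catalog.clearCache() delegates to this CacheManager):

spark.range(10).cache().count()   // materialize one cached relation
spark.catalog.clearCache()        // drops every entry via cacheBuilder.clearCache()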

@@ -119,7 +119,7 @@ class CacheManager extends Logging {
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
-        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
it.remove()
}
}
@@ -138,16 +138,14 @@ class CacheManager extends Logging {
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
-        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        cd.cachedRepresentation.cacheBuilder.clearCache()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
+        val plan = spark.sessionState.executePlan(cd.plan).executedPlan
        val newCache = InMemoryRelation(
-          useCompression = cd.cachedRepresentation.useCompression,
-          batchSize = cd.cachedRepresentation.batchSize,
-          storageLevel = cd.cachedRepresentation.storageLevel,
-          child = spark.sessionState.executePlan(cd.plan).executedPlan,
-          tableName = cd.cachedRepresentation.tableName,
+          cacheBuilder = cd.cachedRepresentation
+            .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null),
          logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
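
A note on the copy call above: for a case class with a second, non-case parameter list, the generated copy() supplies current-value defaults only for the first list, so the second list must be passed again; that is exactly how the materialized buffers get reset here. A toy sketch of the idiom (names invented, not Spark code):

case class Builder(batchSize: Int)(var cached: Option[String])

val b = Builder(1000)(Some("materialized"))
// First-list settings carry over via defaults; the second list is re-supplied,
// deliberately dropping the old cached state.
val reset = b.copy(batchSize = 2000)(cached = None)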
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator


-object InMemoryRelation {
-  def apply(
-      useCompression: Boolean,
-      batchSize: Int,
-      storageLevel: StorageLevel,
-      child: SparkPlan,
-      tableName: Option[String],
-      logicalPlan: LogicalPlan): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
-      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
-}


/**
* CachedBatch is a cached batch of rows.
*
@@ -55,58 +42,41 @@ object InMemoryRelation
private[columnar]
case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)

-case class InMemoryRelation(
-    output: Seq[Attribute],
+case class CachedRDDBuilder(
     useCompression: Boolean,
     batchSize: Int,
     storageLevel: StorageLevel,
-    @transient child: SparkPlan,
+    @transient cachedPlan: SparkPlan,
     tableName: Option[String])(
-    @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics,
-    override val outputOrdering: Seq[SortOrder])
-  extends logical.LeafNode with MultiInstanceRelation {
-
-  override protected def innerChildren: Seq[SparkPlan] = Seq(child)
-
-  override def doCanonicalize(): logical.LogicalPlan =
-    copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
-      storageLevel = StorageLevel.NONE,
-      child = child.canonicalized,
-      tableName = None)(
-        _cachedColumnBuffers,
-        sizeInBytesStats,
-        statsOfPlanToCache,
-        outputOrdering)
+    @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) {

-  override def producedAttributes: AttributeSet = outputSet
-
-  @transient val partitionStatistics = new PartitionStatistics(output)
+  val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator

-  override def computeStats(): Statistics = {
-    if (sizeInBytesStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
-      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
-      // node. When we lookup the cache with a semantically same plan without hint info, the plan
-      // returned by cache lookup should not have hint info. If we lookup the cache with a
-      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
-      // care of it and retain the hint info in the lookup input plan.
-      statsOfPlanToCache.copy(hints = HintInfo())
-    } else {
-      Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
-    }
-  }
+  def cachedColumnBuffers: RDD[CachedBatch] = {
+    if (_cachedColumnBuffers == null) {
+      synchronized {
Member: _cachedColumnBuffers is private[sql], so I'm not sure this synchronized can be very effective.

maropu (Member Author): I feel thread contention is low here, so I prefer the simpler code. But I welcome suggestions for code that is both more efficient and simpler.

cloud-fan (Contributor), Apr 23, 2018: We should either not care about thread-safety at all, or do it right. Please either prove that CachedRDDBuilder will never be accessed by multiple threads and remove these synchronized blocks, or make _cachedColumnBuffers private.

maropu (Member Author), Apr 23, 2018: OK, I'll recheck and update.

maropu (Member Author): Running this PR without the synchronized, I found that multi-threaded queries wrongly built four RDDs for a single cache:

val cachedDf = spark.range(1000000).selectExpr("id AS k", "id AS v").cache()
for (i <- 0 to 3) {
  val thread = new Thread {
    override def run(): Unit = {
      // Start a job in each thread; each one races to materialize the same cache.
      val df = cachedDf.filter('k > 5).groupBy().sum("v")
      df.collect()
    }
  }
  thread.start()
}

Either way, I think we should make _cachedColumnBuffers private, so I fixed that.
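
For reference, the pattern the thread converges on is classic double-checked locking. A self-contained sketch of the idiom (illustrative only, not Spark code; the @volatile makes the lock-free fast path safe in the general case):

final class Lazily[T <: AnyRef](build: () => T) {
  @volatile private var value: T = _

  def get: T = {
    if (value == null) {        // fast path: skip the lock once initialized
      synchronized {
        if (value == null) {    // re-check under the lock: build() runs at most once
          value = build()
        }
      }
    }
    value
  }
}

In the PR itself, it is this re-check inside synchronized that prevents the duplicate RDD construction observed above.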

+        if (_cachedColumnBuffers == null) {
+          _cachedColumnBuffers = buildBuffers()
+        }
+      }
+    }
+    _cachedColumnBuffers
+  }

-  // If the cached column buffers were not passed in, we calculate them in the constructor.
-  // As in Spark, the actual work of caching is lazy.
-  if (_cachedColumnBuffers == null) {
-    buildBuffers()
-  }
+  def clearCache(blocking: Boolean = true): Unit = {
+    if (_cachedColumnBuffers != null) {
+      synchronized {
+        if (_cachedColumnBuffers != null) {
+          _cachedColumnBuffers.unpersist(blocking)
cloud-fan (Contributor): Shall we also set _cachedColumnBuffers = null, so that unpersist won't be called twice?

maropu (Member Author): OK.

+          _cachedColumnBuffers = null
+        }
+      }
+    }
+  }

-  private def buildBuffers(): Unit = {
-    val output = child.output
-    val cached = child.execute().mapPartitionsInternal { rowIterator =>
+  private def buildBuffers(): RDD[CachedBatch] = {
+    val output = cachedPlan.output
+    val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
new Iterator[CachedBatch] {
def next(): CachedBatch = {
val columnBuilders = output.map { attribute =>
@@ -154,32 +124,77 @@ case class InMemoryRelation(

cached.setName(
tableName.map(n => s"In-memory table $n")
-        .getOrElse(StringUtils.abbreviate(child.toString, 1024)))
-    _cachedColumnBuffers = cached
+        .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024)))
+    cached
  }
+}

+object InMemoryRelation {
+
+  def apply(
+      useCompression: Boolean,
+      batchSize: Int,
+      storageLevel: StorageLevel,
+      child: SparkPlan,
+      tableName: Option[String],
+      logicalPlan: LogicalPlan): InMemoryRelation = {
+    val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)()
+    new InMemoryRelation(child.output, cacheBuilder)(
+      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+  }
+
+  def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = {
+    new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)(
+      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+  }
+}
+
+case class InMemoryRelation(
+    output: Seq[Attribute],
+    @transient cacheBuilder: CachedRDDBuilder)(
+    statsOfPlanToCache: Statistics,
+    override val outputOrdering: Seq[SortOrder])
+  extends logical.LeafNode with MultiInstanceRelation {
+
+  override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
+
+  override def doCanonicalize(): logical.LogicalPlan =
+    copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)),
+      cacheBuilder)(
+      statsOfPlanToCache,
+      outputOrdering)
+
+  override def producedAttributes: AttributeSet = outputSet
+
+  @transient val partitionStatistics = new PartitionStatistics(output)
+
+  def cachedPlan: SparkPlan = cacheBuilder.cachedPlan
+
+  override def computeStats(): Statistics = {
+    if (cacheBuilder.sizeInBytesStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+      // node. When we lookup the cache with a semantically same plan without hint info, the plan
+      // returned by cache lookup should not have hint info. If we lookup the cache with a
+      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+      // care of it and retain the hint info in the lookup input plan.
+      statsOfPlanToCache.copy(hints = HintInfo())
+    } else {
+      Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
+    }
+  }
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
-    InMemoryRelation(
-      newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-      _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)
+    InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering)
}

override def newInstance(): this.type = {
new InMemoryRelation(
output.map(_.newInstance()),
-      useCompression,
-      batchSize,
-      storageLevel,
-      child,
-      tableName)(
-      _cachedColumnBuffers,
-      sizeInBytesStats,
+      cacheBuilder)(
statsOfPlanToCache,
outputOrdering).asInstanceOf[this.type]
}

-  def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
-
-  override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
+  override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache)
}
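
Taken together: mutable caching state now lives in CachedRDDBuilder, and InMemoryRelation is a thin logical node over it. A hypothetical sketch of the resulting flow, using names from this diff (physicalPlan and logicalPlan are placeholders):

// 1. Create a lazy builder; no Spark job runs at this point.
val builder = CachedRDDBuilder(
  useCompression = true,
  batchSize = 10000,
  storageLevel = StorageLevel.MEMORY_AND_DISK,
  cachedPlan = physicalPlan,
  tableName = None)()

// 2. Wrap it in the logical node that the optimizer and CacheManager see.
val relation = InMemoryRelation(builder, logicalPlan)

// 3. Column buffers materialize only when a scan first asks for them.
val buffers = relation.cacheBuilder.cachedColumnBuffers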
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -154,7 +154,7 @@ case class InMemoryTableScanExec(
private def updateAttribute(expr: Expression): Expression = {
// attributes can be pruned so using relation's output.
// E.g., relation.output is [id, item] but this scan's output can be [item] only.
-    val attrMap = AttributeMap(relation.child.output.zip(relation.output))
+    val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
expr.transform {
case attr: Attribute => attrMap.getOrElse(attr, attr)
}
@@ -163,16 +163,16 @@ case class InMemoryTableScanExec(
// The cached version does not change the outputPartitioning of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputPartitioning: Partitioning = {
-    relation.child.outputPartitioning match {
+    relation.cachedPlan.outputPartitioning match {
case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
-      case _ => relation.child.outputPartitioning
+      case _ => relation.cachedPlan.outputPartitioning
}
}

// The cached version does not change the outputOrdering of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputOrdering: Seq[SortOrder] =
-    relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
+    relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])

// Keeps relation's partition statistics because we don't serialize relation.
private val stats = relation.partitionStatistics
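
Why the remap above matters: the same cached plan can be scanned with fresh expression ids (for example in a self-join, where one side goes through newInstance() as shown earlier), so partitioning reported against the cached plan's original ids would not match this scan's output. A hypothetical repro shape (illustrative only, attribute ids invented):

import spark.implicits._

// The cached side is hash-partitioned on its own output attribute, say id#1.
val df = spark.range(10).repartition($"id").cache()

// A self-join re-aliases one side's attributes (id#1 -> id#42 via newInstance()),
// so the scan must rewrite HashPartitioning(id#1) to HashPartitioning(id#42)
// for downstream exchange planning to stay correct.
val joined = df.join(df, "id")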
@@ -252,7 +252,7 @@ case class InMemoryTableScanExec(
// within the map Partitions closure.
val schema = stats.schema
val schemaIndex = schema.zipWithIndex
-    val buffers = relation.cachedColumnBuffers
+    val buffers = relation.cacheBuilder.cachedColumnBuffers

buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
val partitionFilter = newPredicate(
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,6 +22,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.CleanerListener
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
@@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val plan = spark.table(tableName).queryExecution.sparkPlan
plan.collect {
case InMemoryTableScanExec(_, _, relation) =>
-        relation.cachedColumnBuffers.id
+        relation.cacheBuilder.cachedColumnBuffers.id
case _ =>
fail(s"Table $tableName is not cached\n" + plan)
}.head
@@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
plan.collect {
case InMemoryTableScanExec(_, _, relation) =>
-        getNumInMemoryTablesRecursively(relation.child) + 1
+        getNumInMemoryTablesRecursively(relation.cachedPlan) + 1
}.sum
}

@@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.catalog.cacheTable("testData")
assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
spark.table("testData").queryExecution.withCachedData.collect {
-        case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r
+        case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r
}.size
}

@@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val toBeCleanedAccIds = new HashSet[Long]

val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.sizeInBytesStats.id
+      case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId1

val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.sizeInBytesStats.id
+      case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId2

@@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}
}
}

+  private def checkIfNoJobTriggered[T](f: => T): T = {
+    var numJobsTriggered = 0
+    val jobListener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numJobsTriggered += 1
+      }
+    }
+    sparkContext.addSparkListener(jobListener)
+    try {
+      val result = f
+      sparkContext.listenerBus.waitUntilEmpty(10000L)
+      assert(numJobsTriggered === 0)
+      result
+    } finally {
+      sparkContext.removeSparkListener(jobListener)
+    }
+  }

test("SPARK-23880 table cache should be lazy and don't trigger any jobs") {
Member: Without the changes in this PR, this test can still pass. :)

maropu (Member Author): Oh, I'll recheck. Thanks!

+    val cachedData = checkIfNoJobTriggered {
+      spark.range(1002).filter('id > 1000).orderBy('id.desc).cache()
+    }
+    assert(cachedData.collect === Seq(1001))
+  }
}
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext {
test("CollectLimit can appear in the middle of a plan when caching is used") {
val query = testData.select('key, 'value).limit(2).cache()
val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
-    assert(planned.child.isInstanceOf[CollectLimitExec])
+    assert(planned.cachedPlan.isInstanceOf[CollectLimitExec])
}

test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None,
data.logicalPlan)

-    assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
-    inMemoryRelation.cachedColumnBuffers.collect().head match {
+    assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel)
+    inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match {
case _: CachedBatch =>
case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
}
@@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(cached, expectedAnswer)

// Check that the right size was calculated.
-    assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
+    assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
}

test("access primitive-type columns in CachedBatch without whole stage codegen") {