diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 0f280d236203..04717b683c49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -226,14 +226,16 @@ case class CachedRDDBuilder( } def cachedColumnBuffers: RDD[CachedBatch] = { - if (_cachedColumnBuffers == null) { - synchronized { - if (_cachedColumnBuffers == null) { - _cachedColumnBuffers = buildBuffers() - } + val cached = _cachedColumnBuffers + if (cached != null) { + return cached + } + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() } + _cachedColumnBuffers } - _cachedColumnBuffers } def clearCache(blocking: Boolean = false): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 4ea945d105e7..8ddd92dfe931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -619,4 +619,55 @@ class InMemoryColumnarQuerySuite extends QueryTest assert(exceptionCnt.get == 0) } + + test("SPARK-50572: InMemoryRelation.cachedColumnBuffers should be thread-safe") { + val qe = spark.range(1).queryExecution + val plan = qe.executedPlan + val serializer = new TestCachedBatchSerializer(true, 1) + val cachedRDDBuilder = CachedRDDBuilder(serializer, MEMORY_ONLY, plan, None, qe.logical) + + @volatile var stopped = false + + val th1 = new Thread { + override def run(): Unit = { + while (!stopped) { + assert(cachedRDDBuilder.cachedColumnBuffers != null) + } + } + } + + val th2 = new Thread { + override def run(): Unit = { + while (!stopped) { + cachedRDDBuilder.clearCache() + } + } + } + + val th3 = new Thread { + override def run(): Unit = { + Thread.sleep(3000L) + stopped = true + } + } + + val exceptionCnt = new AtomicInteger + val exceptionHandler: Thread.UncaughtExceptionHandler = (_: Thread, cause: Throwable) => { + exceptionCnt.incrementAndGet + fail(cause) + } + + th1.setUncaughtExceptionHandler(exceptionHandler) + th2.setUncaughtExceptionHandler(exceptionHandler) + th1.start() + th2.start() + th3.start() + th1.join() + th2.join() + th3.join() + + cachedRDDBuilder.clearCache() + + assert(exceptionCnt.get == 0) + } }