From 6797f37b58f74fe0a9fdcf8441d617647ef501b9 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 13 Dec 2024 14:44:03 +0100 Subject: [PATCH] Fix race condition in InMemoryRelation The previous code had a race condition that mean that `cachedColumnBuffers` could return `null` if another thread was concurrently was calling `clearCache`. The bug is caused by us checking _cachedColumnBuffers and return it as two separate operations outside a synchronized block. So it possible for another thread to set it to `null` after the check but before the return. ``` java.lang.NullPointerException: null at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.filteredCachedBatches(InMemoryTableScanExec.scala:156) at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.inputRDD$lzycompute(InMemoryTableScanExec.scala:98) at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.inputRDD(InMemoryTableScanExec.scala:84) at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.doExecute(InMemoryTableScanExec.scala:163) at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195) at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191) at org.apache.spark.sql.execution.InputAdapter.inputRDD(WholeStageCodegenExec.scala:527) at org.apache.spark.sql.execution.InputRDDCodegen.inputRDDs(WholeStageCodegenExec.scala:455) at org.apache.spark.sql.execution.InputRDDCodegen.inputRDDs$(WholeStageCodegenExec.scala:454) at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:498) at org.apache.spark.sql.execution.ProjectExec.inputRDDs(basicPhysicalOperators.scala:51) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:751) at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195) at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191) at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:364) at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:445) at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4218) at org.apache.spark.sql.Dataset.$anonfun$collect$1(Dataset.scala:3459) at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4208) at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:526) ... 23 more ``` --- .../execution/columnar/InMemoryRelation.scala | 14 ++--- .../columnar/InMemoryColumnarQuerySuite.scala | 51 +++++++++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 0f280d236203..04717b683c49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -226,14 +226,16 @@ case class CachedRDDBuilder( } def cachedColumnBuffers: RDD[CachedBatch] = { - if (_cachedColumnBuffers == null) { - synchronized { - if (_cachedColumnBuffers == null) { - _cachedColumnBuffers = buildBuffers() - } + val cached = _cachedColumnBuffers + if (cached != null) { + return cached + } + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() } + _cachedColumnBuffers } - _cachedColumnBuffers } def clearCache(blocking: Boolean = false): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 4ea945d105e7..8ddd92dfe931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -619,4 +619,55 @@ class InMemoryColumnarQuerySuite extends QueryTest assert(exceptionCnt.get == 0) } + + test("SPARK-50572: InMemoryRelation.cachedColumnBuffers should be thread-safe") { + val qe = spark.range(1).queryExecution + val plan = qe.executedPlan + val serializer = new TestCachedBatchSerializer(true, 1) + val cachedRDDBuilder = CachedRDDBuilder(serializer, MEMORY_ONLY, plan, None, qe.logical) + + @volatile var stopped = false + + val th1 = new Thread { + override def run(): Unit = { + while (!stopped) { + assert(cachedRDDBuilder.cachedColumnBuffers != null) + } + } + } + + val th2 = new Thread { + override def run(): Unit = { + while (!stopped) { + cachedRDDBuilder.clearCache() + } + } + } + + val th3 = new Thread { + override def run(): Unit = { + Thread.sleep(3000L) + stopped = true + } + } + + val exceptionCnt = new AtomicInteger + val exceptionHandler: Thread.UncaughtExceptionHandler = (_: Thread, cause: Throwable) => { + exceptionCnt.incrementAndGet + fail(cause) + } + + th1.setUncaughtExceptionHandler(exceptionHandler) + th2.setUncaughtExceptionHandler(exceptionHandler) + th1.start() + th2.start() + th3.start() + th1.join() + th2.join() + th3.join() + + cachedRDDBuilder.clearCache() + + assert(exceptionCnt.get == 0) + } }