diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index 784bea899c4c8..e3ff7c5f05f0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.optimizer.ScalarSubqueryReference import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.BloomFilter @@ -58,6 +59,7 @@ case class BloomFilterMightContain( case GetStructField(subquery: PlanExpression[_], _, _) if !subquery.containsPattern(OUTER_REFERENCE) => TypeCheckResult.TypeCheckSuccess + case _: ScalarSubqueryReference => TypeCheckResult.TypeCheckSuccess case _ => DataTypeMismatch( errorSubClass = "BLOOM_FILTER_BINARY_OP_WRONG_TYPE", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index fb279b1db6fc9..0a8e31b6bac4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -367,4 +367,23 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { checkNumBits(100, 2935) checkNumBits(1, 38) } + + test("SPARK-54336: Fix BloomFilterMightContain type check with ScalarSubqueryReference") { + val table = "bloom_filter_test" + withTempView(table) { + Seq(0).toDF("col").createOrReplaceTempView(table) + val df = sql( + s""" + |SELECT + | (SELECT + | first(might_contain( + | (SELECT bloom_filter_agg(col) FROM $table), + | 0L + | )) + | FROM $table) + |FROM $table + |""".stripMargin) + checkAnswer(df, Row(true)) + } + } }