Skip to content

Commit 2e1c3dc

Browse files
cloud-fan and andrej-db
committed
[SPARK-50087] Robust handling of boolean expressions in CASE WHEN for MsSqlServer and future connectors
### What changes were proposed in this pull request? This PR proposes to propagate the `isPredicate` info in `V2ExpressionBuilder` and wrap the children of CASE WHEN expression (only `Predicate`s) with `IIF(<>, 1, 0)` for MsSqlServer. This is done to force returning an int instead of a boolean, as SqlServer cannot handle boolean expressions as a return type in CASE WHEN. E.g. ```CASE WHEN ... ELSE a = b END``` Old behavior: ```CASE WHEN ... ELSE a = b END = 1``` New behavior: Since in SqlServer a `= 1` is appended to the CASE WHEN, THEN and ELSE blocks must return an int. Therefore the final expression becomes: ```CASE WHEN ... ELSE IIF(a = b, 1, 0) END = 1``` ### Why are the changes needed? A user cannot work with MsSqlServer data using CASE WHEN clauses or IF clauses if they wish to return a boolean value. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests to MsSqlServerIntegrationSuite ### Was this patch authored or co-authored using generative AI tooling? No Closes #48621 from andrej-db/SPARK-50087-CaseWhen. Lead-authored-by: Wenchen Fan <[email protected]> Co-authored-by: andrej-db <[email protected]> Co-authored-by: Andrej Gobeljić <[email protected]> Co-authored-by: andrej-gobeljic_data <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 136c722 commit 2e1c3dc

File tree

4 files changed

+114
-8
lines changed

4 files changed

+114
-8
lines changed

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2
2020
import java.sql.Connection
2121

2222
import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException}
23+
import org.apache.spark.rdd.RDD
2324
import org.apache.spark.sql.AnalysisException
25+
import org.apache.spark.sql.catalyst.InternalRow
26+
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
27+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
2428
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
2529
import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker
2630
import org.apache.spark.sql.types._
@@ -37,6 +41,17 @@ import org.apache.spark.tags.DockerTest
3741
@DockerTest
3842
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
3943

44+
def getExternalEngineQuery(executedPlan: SparkPlan): String = {
45+
getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
46+
}
47+
48+
def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = {
49+
val queryNode = executedPlan.collect { case r: RowDataSourceScanExec =>
50+
r
51+
}.head
52+
queryNode.rdd
53+
}
54+
4055
override def excluded: Seq[String] = Seq(
4156
"simple scan with OFFSET",
4257
"simple scan with LIMIT and OFFSET",
@@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
146161
|""".stripMargin)
147162
assert(df.collect().length == 2)
148163
}
164+
165+
test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") {
166+
val df = sql(
167+
s"""|SELECT * FROM $catalogName.employee
168+
|WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END
169+
|""".stripMargin
170+
)
171+
172+
// scalastyle:off
173+
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
174+
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
175+
)
176+
// scalastyle:on
177+
df.collect()
178+
}
179+
180+
test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") {
181+
val df = sql(
182+
s"""|SELECT * FROM $catalogName.employee
183+
|WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END
184+
|""".stripMargin
185+
)
186+
187+
// scalastyle:off
188+
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
189+
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
190+
)
191+
// scalastyle:on
192+
df.collect()
193+
}
194+
195+
test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") {
196+
val df = sql(
197+
s"""|SELECT * FROM $catalogName.employee
198+
|WHERE CASE WHEN (name = 'Legolas') THEN
199+
| CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END
200+
| ELSE (name = 'Sauron') END
201+
|""".stripMargin
202+
)
203+
204+
// scalastyle:off
205+
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
206+
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
207+
)
208+
// scalastyle:on
209+
df.collect()
210+
}
211+
212+
test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") {
213+
val df = sql(
214+
s"""|SELECT * FROM $catalogName.employee
215+
|WHERE CASE WHEN (name = 'Legolas') THEN
216+
| CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END
217+
| ELSE 'Sauron' END = name
218+
|""".stripMargin
219+
)
220+
221+
// scalastyle:off
222+
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
223+
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
224+
)
225+
// scalastyle:on
226+
df.collect()
227+
}
149228
}

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
221221
case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
222222
case caseWhen @ CaseWhen(branches, elseValue) =>
223223
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
224-
val values = branches.map(_._2).flatMap(generateExpression(_))
225-
val elseExprOpt = elseValue.flatMap(generateExpression(_))
224+
val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate))
225+
val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
226226
if (conditions.length == branches.length && values.length == branches.length &&
227227
elseExprOpt.size == elseValue.size) {
228228
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
@@ -421,7 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
421421
children: Seq[Expression],
422422
dataType: DataType,
423423
isPredicate: Boolean): Option[V2Expression] = {
424-
val childrenExpressions = children.flatMap(generateExpression(_))
424+
val childrenExpressions = children.flatMap(generateExpression(_, isPredicate))
425425
if (childrenExpressions.length == children.length) {
426426
if (isPredicate && dataType.isInstanceOf[BooleanType]) {
427427
Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
4242
import org.apache.spark.sql.connector.catalog.index.TableIndex
4343
import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
4444
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
45+
import org.apache.spark.sql.connector.expressions.filter.Predicate
4546
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
4647
import org.apache.spark.sql.errors.QueryCompilationErrors
4748
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
@@ -377,6 +378,18 @@ abstract class JdbcDialect extends Serializable with Logging {
377378
}
378379

379380
private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
381+
// Some dialects do not support boolean type and this convenient util function is
382+
// provided to generate SQL string without boolean values.
383+
protected def inputToSQLNoBool(input: Expression): String = input match {
384+
case p: Predicate if p.name() == "ALWAYS_TRUE" => "1"
385+
case p: Predicate if p.name() == "ALWAYS_FALSE" => "0"
386+
case p: Predicate => predicateToIntSQL(inputToSQL(p))
387+
case _ => inputToSQL(input)
388+
}
389+
390+
protected def predicateToIntSQL(input: String): String =
391+
"CASE WHEN " + input + " THEN 1 ELSE 0 END"
392+
380393
override def visitLiteral(literal: Literal[_]): String = {
381394
Option(literal.value()).map(v =>
382395
compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)

sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
5959
supportedFunctions.contains(funcName)
6060

6161
class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
62+
override protected def predicateToIntSQL(input: String): String =
63+
"IIF(" + input + ", 1, 0)"
6264
override def visitSortOrder(
6365
sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
6466
(sortDirection, nullOrdering) match {
@@ -87,12 +89,24 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
8789
expr match {
8890
case e: Predicate => e.name() match {
8991
case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
90-
val Array(l, r) = e.children().map {
91-
case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END"
92-
case o => inputToSQL(o)
93-
}
92+
val Array(l, r) = e.children().map(inputToSQLNoBool)
9493
visitBinaryComparison(e.name(), l, r)
95-
case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1"
94+
case "CASE_WHEN" =>
95+
// Since MsSqlServer cannot handle boolean expressions inside
96+
// a CASE WHEN, it is necessary to convert those to another
97+
// CASE WHEN expression that will return 1 or 0 depending on
98+
// the result.
99+
// Example:
100+
// In: ... CASE WHEN a = b THEN c = d ... END
101+
// Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1
102+
val stringArray = e.children().grouped(2).flatMap {
103+
case Array(whenExpression, thenExpression) =>
104+
Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression))
105+
case Array(elseExpression) =>
106+
Array(inputToSQLNoBool(elseExpression))
107+
}.toArray
108+
109+
visitCaseWhen(stringArray) + " = 1"
96110
case _ => super.build(expr)
97111
}
98112
case _ => super.build(expr)

0 commit comments

Comments
 (0)