From e0485fc398b51ee3b9bd14225316ac5302064cb5 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 7 Mar 2024 17:29:51 +0100 Subject: [PATCH 01/26] Fail all 10 regexp expressions for non-binary collations --- .../sql/catalyst/util/CollationFactory.java | 2 + .../catalyst/expressions/CollationUtils.scala | 86 +++++ .../expressions/regexpExpressions.scala | 84 +++++ .../org/apache/spark/sql/CollationSuite.scala | 321 ++++++++++++++++++ 4 files changed, 493 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index c0c011926be9..cca29b8670af 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -65,6 +65,7 @@ public static class Collation { * byte for byte equal. All accent or case-insensitive collations are considered non-binary. */ public final boolean isBinaryCollation; + public final boolean isLowercaseCollation; public Collation( String collationName, @@ -79,6 +80,7 @@ public Collation( this.version = version; this.hashFunction = hashFunction; this.isBinaryCollation = isBinaryCollation; + this.isLowercaseCollation = collationName.equals("UCS_BASIC_LCASE"); if (isBinaryCollation) { this.equalsFunction = UTF8String::equals; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala new file mode 100644 index 000000000000..27e9e9b0f207 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types.{DataType, StringType} + +object CollationUtils { + def checkCollationCompatibility( + superCheck: => TypeCheckResult, + collationId: Int, + rightDataType: DataType + ): TypeCheckResult = { + val checkResult = superCheck + if (checkResult.isFailure) return checkResult + // Additional check needed for collation compatibility + val rightCollationId: Int = rightDataType.asInstanceOf[StringType].collationId + if (collationId != rightCollationId) { + return DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, + "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName + ) + ) + } + TypeCheckResult.TypeCheckSuccess + } + + final val SUPPORT_BINARY_ONLY: Int = 0 + final val SUPPORT_LOWERCASE: Int = 1 + final val SUPPORT_ALL_COLLATIONS: Int = 2 + + def checkCollationSupport( + superCheck: => TypeCheckResult, + collationId: Int, + functionName: String, + supportLevel: Int = SUPPORT_BINARY_ONLY + ): TypeCheckResult = { + val checkResult = superCheck + if (checkResult.isFailure) return checkResult + // Additional check needed for collation support + val collation = CollationFactory.fetchCollation(collationId) + supportLevel match { + case SUPPORT_BINARY_ONLY => + if (!collation.isBinaryCollation) { + throwUnsupportedCollation(functionName, collation.collationName) + } + case SUPPORT_LOWERCASE => + if (!collation.isBinaryCollation && !collation.isLowercaseCollation) { + throwUnsupportedCollation(functionName, collation.collationName) + } + case SUPPORT_ALL_COLLATIONS => // No additional checks needed + case _ => throw new IllegalArgumentException("Invalid collation support level.") + } + TypeCheckResult.TypeCheckSuccess + } + + private def throwUnsupportedCollation(functionName: String, collationName: String): Unit = { + throw new SparkException( + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + messageParameters = Map( + "functionName" -> functionName, + "collationName" -> collationName), + cause = null + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b33de303b5d5..5516a7969c2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -46,6 +46,13 @@ abstract class StringRegexExpression extends BinaryExpression override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, right.dataType) + } + // try cache foldable pattern private lazy val cache: Pattern = right match { case p: Expression if p.foldable => @@ -130,6 +137,11 @@ abstract class StringRegexExpression extends BinaryExpression case class Like(left: Expression, right: Expression, escapeChar: Char) extends StringRegexExpression { + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, "like", CollationUtils.SUPPORT_BINARY_ONLY) + } + def this(left: Expression, right: Expression) = this(left, right, '\\') override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) @@ -260,6 +272,15 @@ case class ILike( override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, right.dataType) + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, "ilike", CollationUtils.SUPPORT_BINARY_ONLY) + } + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { copy(left = newLeft, right = newRight) @@ -461,6 +482,11 @@ case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends Like // scalastyle:on line.contains.tab line.size.limit case class RLike(left: Expression, right: Expression) extends StringRegexExpression { + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, "rlike", CollationUtils.SUPPORT_BINARY_ONLY) + } + override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"RLIKE($left, $right)" @@ -545,6 +571,16 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(StringType, containsNull = false) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + + final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, second.dataType) + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } + override def first: Expression = str override def second: Expression = regex override def third: Expression = limit @@ -614,6 +650,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio def this(subject: Expression, regexp: Expression, rep: Expression) = this(subject, regexp, rep, Literal(1)) + final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { @@ -643,6 +681,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ) ) } + + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, second.dataType) + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) } // last regex in string, we will update the pattern iff regexp value changed. @@ -772,6 +815,13 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + + final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, regexp.dataType) + } override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -849,6 +899,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } override def dataType: DataType = StringType + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } + override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -948,6 +1004,11 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres } override def dataType: DataType = ArrayType(StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } override def prettyName: String = "regexp_extract_all" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1022,6 +1083,15 @@ case class RegExpCount(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, right.dataType) + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = copy(left = newChildren(0), right = newChildren(1)) @@ -1061,6 +1131,15 @@ case class RegExpSubStr(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationCompatibility( + super.checkInputDataTypes(), collationId, right.dataType) + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = copy(left = newChildren(0), right = newChildren(1)) @@ -1113,6 +1192,11 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) } override def dataType: DataType = IntegerType + + override def checkInputDataTypes(): TypeCheckResult = { + CollationUtils.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + } override def prettyName: String = "regexp_instr" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 2ff50722c3da..a8907e7599cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -390,6 +390,327 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }) } + case class CollationTestFail[R](left: String, right: String, collation: String) + + test("Support Like string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%B%", "UCS_BASIC", true), + CollationTestCase("ABC", "%B%", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "like", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support ILike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC", true), + CollationTestCase("ABC", "%b%", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "ilike", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RLike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", ".B.", "UCS_BASIC", true), + CollationTestCase("ABC", ".B.", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", ".b.", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", ".b.", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "rlike", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support StringSplit string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "[B]", "UCS_BASIC", 2), + CollationTestCase("ABC", "[B]", "UNICODE", 2) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "b", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABC", "b", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "split", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpReplace string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "AFFFE"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE") + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),'FFF')"), + Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),'FFF')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_replace", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpExtract string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0)"), Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql( + s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0)") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_extract", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpExtractAll string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0))"), Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql( + s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_extract_all", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpCount string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_count", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpSubStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_substr", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpInStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 2), + CollationTestCase("ABCDE", ".C.", "UNICODE", 2) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "regexp_instr", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + test("aggregates count respects collation") { Seq( ("ucs_basic", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), From 355eee9ef32471456782dc8941dfe76793834109 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 8 Mar 2024 08:57:41 +0100 Subject: [PATCH 02/26] Reformat code --- ...s.scala => CollationTypeConstraints.scala} | 37 ++++++----- .../expressions/regexpExpressions.scala | 64 +++++++++++-------- 2 files changed, 57 insertions(+), 44 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{CollationUtils.scala => CollationTypeConstraints.scala} (77%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala similarity index 77% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index 27e9e9b0f207..5382b96a5ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -23,12 +23,11 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types.{DataType, StringType} -object CollationUtils { +object CollationTypeConstraints { def checkCollationCompatibility( - superCheck: => TypeCheckResult, - collationId: Int, - rightDataType: DataType - ): TypeCheckResult = { + superCheck: => TypeCheckResult, + collationId: Int, + rightDataType: DataType): TypeCheckResult = { val checkResult = superCheck if (checkResult.isFailure) return checkResult // Additional check needed for collation compatibility @@ -45,30 +44,34 @@ object CollationUtils { TypeCheckResult.TypeCheckSuccess } - final val SUPPORT_BINARY_ONLY: Int = 0 - final val SUPPORT_LOWERCASE: Int = 1 - final val SUPPORT_ALL_COLLATIONS: Int = 2 + object CollationSupportLevel extends Enumeration { + type CollationSupportLevel = Value + + val SUPPORT_BINARY_ONLY: Value = Value(0) + val SUPPORT_LOWERCASE: Value = Value(1) + val SUPPORT_ALL_COLLATIONS: Value = Value(2) + } def checkCollationSupport( - superCheck: => TypeCheckResult, - collationId: Int, - functionName: String, - supportLevel: Int = SUPPORT_BINARY_ONLY - ): TypeCheckResult = { + superCheck: => TypeCheckResult, + collationId: Int, + functionName: String, + collationSupportLevel: CollationSupportLevel.CollationSupportLevel) + : TypeCheckResult = { val checkResult = superCheck if (checkResult.isFailure) return checkResult // Additional check needed for collation support val collation = CollationFactory.fetchCollation(collationId) - supportLevel match { - case SUPPORT_BINARY_ONLY => + collationSupportLevel match { + case CollationSupportLevel.SUPPORT_BINARY_ONLY => if (!collation.isBinaryCollation) { throwUnsupportedCollation(functionName, collation.collationName) } - case SUPPORT_LOWERCASE => + case CollationSupportLevel.SUPPORT_LOWERCASE => if (!collation.isBinaryCollation && !collation.isLowercaseCollation) { throwUnsupportedCollation(functionName, collation.collationName) } - case SUPPORT_ALL_COLLATIONS => // No additional checks needed + case CollationSupportLevel.SUPPORT_ALL_COLLATIONS => // No additional checks needed case _ => throw new IllegalArgumentException("Invalid collation support level.") } TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 5516a7969c2b..dad38333228a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -49,7 +49,7 @@ abstract class StringRegexExpression extends BinaryExpression final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, right.dataType) } @@ -138,8 +138,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) extends StringRegexExpression { override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, "like", CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, "like", + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } def this(left: Expression, right: Expression) = this(left, right, '\\') @@ -275,10 +276,11 @@ case class ILike( final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, right.dataType) - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, "ilike", CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, "ilike", + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override protected def withNewChildrenInternal( @@ -483,8 +485,9 @@ case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends Like case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, "rlike", CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, "rlike", + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def escape(v: String): String = v @@ -575,10 +578,11 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, second.dataType) - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def first: Expression = str @@ -682,10 +686,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ) } - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, second.dataType) - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } // last regex in string, we will update the pattern iff regexp value changed. @@ -819,7 +824,7 @@ abstract class RegExpExtractBase final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, regexp.dataType) } override def first: Expression = subject @@ -901,8 +906,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def dataType: DataType = StringType override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_extract" @@ -1006,8 +1012,9 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres override def dataType: DataType = ArrayType(StringType) override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_extract_all" @@ -1086,10 +1093,11 @@ case class RegExpCount(left: Expression, right: Expression) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, right.dataType) - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override protected def withNewChildrenInternal( @@ -1134,10 +1142,11 @@ case class RegExpSubStr(left: Expression, right: Expression) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationCompatibility( + CollationTypeConstraints.checkCollationCompatibility( super.checkInputDataTypes(), collationId, right.dataType) - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override protected def withNewChildrenInternal( @@ -1194,8 +1203,9 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) override def dataType: DataType = IntegerType override def checkInputDataTypes(): TypeCheckResult = { - CollationUtils.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, CollationUtils.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport( + super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_instr" From cde9eff7a9cff341c1dada42482710962e5a2516 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:11:50 +0100 Subject: [PATCH 03/26] Refactoring and better test coverage --- .../CollationTypeConstraints.scala | 58 +- .../expressions/regexpExpressions.scala | 67 +-- .../sql/CollationRegexpExpressionsSuite.scala | 517 ++++++++++++++++++ .../org/apache/spark/sql/CollationSuite.scala | 321 ----------- 4 files changed, 580 insertions(+), 383 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index 5382b96a5ba1..4fb1a6325d52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -17,29 +17,48 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.types.{DataType, StringType} object CollationTypeConstraints { + def checkCollationCompatibilityAndSupport( + checkResult: TypeCheckResult, + collationId: Int, + dataTypes: Seq[DataType], + functionName: String, + collationSupportLevel: CollationSupportLevel.CollationSupportLevel): TypeCheckResult = { + val resultCompatibility = checkCollationCompatibility(checkResult, collationId, dataTypes) + checkCollationSupport(resultCompatibility, collationId, functionName, collationSupportLevel) + } + def checkCollationCompatibility( - superCheck: => TypeCheckResult, + checkResult: TypeCheckResult, collationId: Int, - rightDataType: DataType): TypeCheckResult = { - val checkResult = superCheck - if (checkResult.isFailure) return checkResult + dataTypes: Seq[DataType]): TypeCheckResult = { + if (checkResult.isFailure) { + return checkResult + } + val collationName = CollationFactory.fetchCollation(collationId).collationName // Additional check needed for collation compatibility - val rightCollationId: Int = rightDataType.asInstanceOf[StringType].collationId - if (collationId != rightCollationId) { - return DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, - "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName - ) - ) + for (dataType <- dataTypes) { + dataType match { + case stringType: StringType => + if (stringType.collationId != collationId) { + val collation = CollationFactory.fetchCollation(stringType.collationId) + return DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> collationName, + "collationNameRight" -> collation.collationName + ) + ) + } + case _ => + } } TypeCheckResult.TypeCheckSuccess } @@ -53,13 +72,14 @@ object CollationTypeConstraints { } def checkCollationSupport( - superCheck: => TypeCheckResult, + checkResult: TypeCheckResult, collationId: Int, functionName: String, collationSupportLevel: CollationSupportLevel.CollationSupportLevel) : TypeCheckResult = { - val checkResult = superCheck - if (checkResult.isFailure) return checkResult + if (checkResult.isFailure) { + return checkResult + } // Additional check needed for collation support val collation = CollationFactory.fetchCollation(collationId) collationSupportLevel match { @@ -72,7 +92,7 @@ object CollationTypeConstraints { throwUnsupportedCollation(functionName, collation.collationName) } case CollationSupportLevel.SUPPORT_ALL_COLLATIONS => // No additional checks needed - case _ => throw new IllegalArgumentException("Invalid collation support level.") + case _ => throw new SparkIllegalArgumentException("Invalid collation support level.") } TypeCheckResult.TypeCheckSuccess } @@ -81,7 +101,7 @@ object CollationTypeConstraints { throw new SparkException( errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", messageParameters = Map( - "functionName" -> functionName, + "functionName" -> toSQLId(functionName), "collationName" -> collationName), cause = null ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index dad38333228a..8e093f2ddb83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -49,8 +49,8 @@ abstract class StringRegexExpression extends BinaryExpression final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, right.dataType) + CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, + children.map(_.dataType)) } // try cache foldable pattern @@ -138,9 +138,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) extends StringRegexExpression { override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, "like", - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, + prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } def this(left: Expression, right: Expression) = this(left, right, '\\') @@ -276,10 +275,8 @@ case class ILike( final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, right.dataType) - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, "ilike", + CollationTypeConstraints.checkCollationCompatibilityAndSupport( + super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } @@ -485,9 +482,8 @@ case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends Like case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, "rlike", - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, + prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def escape(v: String): String = v @@ -578,10 +574,8 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, second.dataType) - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.checkCollationCompatibilityAndSupport( + super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } @@ -657,10 +651,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - val defaultCheck = super.checkInputDataTypes() - if (defaultCheck.isFailure) { - return defaultCheck - } if (!pos.foldable) { return DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -686,10 +676,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ) } - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, second.dataType) - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.checkCollationCompatibilityAndSupport( + super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } @@ -824,8 +812,8 @@ abstract class RegExpExtractBase final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, regexp.dataType) + CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, + children.map(_.dataType)) } override def first: Expression = subject override def second: Expression = regexp @@ -906,9 +894,8 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def dataType: DataType = StringType override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, + prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_extract" @@ -1012,9 +999,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres override def dataType: DataType = ArrayType(StringType) override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, + prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_extract_all" @@ -1093,10 +1079,8 @@ case class RegExpCount(left: Expression, right: Expression) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, right.dataType) - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.checkCollationCompatibilityAndSupport( + super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } @@ -1142,10 +1126,8 @@ case class RegExpSubStr(left: Expression, right: Expression) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility( - super.checkInputDataTypes(), collationId, right.dataType) - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, + CollationTypeConstraints.checkCollationCompatibilityAndSupport( + super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } @@ -1203,9 +1185,8 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) override def dataType: DataType = IntegerType override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport( - super.checkInputDataTypes(), collationId, prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) + CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, + prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } override def prettyName: String = "regexp_instr" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala new file mode 100644 index 000000000000..18a09214228a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.test.SharedSparkSession + +class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession { + + test("Check collation compatibility for regexp functions") { + // like + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT 'ABC' like collate('%b%', 'UNICODE_CI')") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"ABC LIKE collate(%b%)\"" + ), + context = ExpectedContext(fragment = + s"like collate('%b%', 'UNICODE_CI')", + start = 13, stop = 45) + ) + // ilike + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT 'ABC' ilike collate('%b%', 'UNICODE_CI')") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"ilike(ABC, collate(%b%))\"" + ), + context = ExpectedContext(fragment = + s"ilike collate('%b%', 'UNICODE_CI')", + start = 13, stop = 46) + ) + // rlike + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT 'ABC' rlike collate('%b%', 'UNICODE_CI')") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"RLIKE(ABC, collate(%b%))\"" + ), + context = ExpectedContext(fragment = + s"rlike collate('%b%', 'UNICODE_CI')", + start = 13, stop = 46) + ) + // split + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT split('ABC', collate('[B]', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"split(ABC, collate([B]), -1)\"" + ), + context = ExpectedContext(fragment = + s"split('ABC', collate('[B]', 'UNICODE_CI'))", + start = 7, stop = 48) + ) + // regexp_replace + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_replace('ABCDE', collate('.c.', 'UNICODE_CI'), 'F')") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_replace(ABCDE, collate(.c.), F, 1)\"" + ), + context = ExpectedContext(fragment = + s"regexp_replace('ABCDE', collate('.c.', 'UNICODE_CI'), 'F')", + start = 7, stop = 64) + ) + // regexp_extract + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_extract('ABCDE', collate('.c.', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_extract(ABCDE, collate(.c.), 1)\"" + ), + context = ExpectedContext(fragment = + s"regexp_extract('ABCDE', collate('.c.', 'UNICODE_CI'))", + start = 7, stop = 59) + ) + // regexp_extract_all + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_extract_all('ABCDE', collate('.c.', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_extract_all(ABCDE, collate(.c.), 1)\"" + ), + context = ExpectedContext(fragment = + s"regexp_extract_all('ABCDE', collate('.c.', 'UNICODE_CI'))", + start = 7, stop = 63) + ) + // regexp_count + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_count('ABCDE', collate('.c.', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_count(ABCDE, collate(.c.))\"" + ), + context = ExpectedContext(fragment = + s"regexp_count('ABCDE', collate('.c.', 'UNICODE_CI'))", + start = 7, stop = 57) + ) + // regexp_substr + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_substr('ABCDE', collate('.c.', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_substr(ABCDE, collate(.c.))\"" + ), + context = ExpectedContext(fragment = + s"regexp_substr('ABCDE', collate('.c.', 'UNICODE_CI'))", + start = 7, stop = 58) + ) + // regexp_instr + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT regexp_instr('ABCDE', collate('.c.', 'UNICODE_CI'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"UCS_BASIC", + "collationNameRight" -> s"UNICODE_CI", + "sqlExpr" -> "\"regexp_instr(ABCDE, collate(.c.), 0)\"" + ), + context = ExpectedContext(fragment = + s"regexp_instr('ABCDE', collate('.c.', 'UNICODE_CI'))", + start = 7, stop = 57) + ) + } + + case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R) + case class CollationTestFail[R](left: String, right: String, collation: String) + + test("Support Like string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%B%", "UCS_BASIC", true), + CollationTestCase("ABC", "%B%", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`like`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support ILike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC", true), + CollationTestCase("ABC", "%b%", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`ilike`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RLike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", ".B.", "UCS_BASIC", true), + CollationTestCase("ABC", ".B.", "UNICODE", true) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + + s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", ".b.", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", ".b.", "UNICODE_CI", false) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + + s"collate('${testCase.right}', '${testCase.collation}')") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`rlike`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support StringSplit string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "[B]", "UCS_BASIC", 2), + CollationTestCase("ABC", "[B]", "UNICODE", 2) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "b", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABC", "b", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`split`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpReplace string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "AFFFE"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE") + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')" + + s",collate('FFF', '${testCase.collation}'))"), + Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}')" + + s",collate('FFF', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_replace`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpExtract string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0)"), + Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql( + s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0)") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_extract`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpExtractAll string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + ) + checks.foreach(testCase => { + checkAnswer( + sql(s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0))"), + Row(testCase.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql( + s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'),0))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_extract_all`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpCount string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_count`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpSubStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_substr`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + + test("Support RegExpInStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 2), + CollationTestCase("ABCDE", ".C.", "UNICODE", 2) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(testCase => { + checkError( + exception = intercept[SparkException] { + sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + + s",collate('${testCase.right}', '${testCase.collation}'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "`regexp_instr`", + "collationName" -> s"${testCase.collation}" + ) + ) + }) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index a8907e7599cc..2ff50722c3da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -390,327 +390,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }) } - case class CollationTestFail[R](left: String, right: String, collation: String) - - test("Support Like string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%B%", "UCS_BASIC", true), - CollationTestCase("ABC", "%B%", "UNICODE", true) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + - s"collate('${testCase.right}', '${testCase.collation}')") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "like", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support ILike string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC", true), - CollationTestCase("ABC", "%b%", "UNICODE", true) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + - s"collate('${testCase.right}', '${testCase.collation}')") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "ilike", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RLike string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", ".B.", "UCS_BASIC", true), - CollationTestCase("ABC", ".B.", "UNICODE", true) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", ".b.", "UCS_BASIC_LCASE", false), - CollationTestCase("ABC", ".b.", "UNICODE_CI", false) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + - s"collate('${testCase.right}', '${testCase.collation}')") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "rlike", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support StringSplit string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "[B]", "UCS_BASIC", 2), - CollationTestCase("ABC", "[B]", "UNICODE", 2) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')))"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "b", "UCS_BASIC_LCASE", 0), - CollationTestCase("ABC", "b", "UNICODE_CI", 0) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')))") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "split", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpReplace string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "AFFFE"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE") - ) - checks.foreach(testCase => { - checkAnswer( - sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),'FFF')"), - Row(testCase.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),'FFF')") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_replace", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpExtract string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") - ) - checks.foreach(testCase => { - checkAnswer( - sql(s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0)"), Row(testCase.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql( - s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0)") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_extract", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpExtractAll string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1) - ) - checks.foreach(testCase => { - checkAnswer( - sql(s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0))"), Row(testCase.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql( - s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0))") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_extract_all", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpCount string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_count", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpSubStr string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_substr", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - - test("Support RegExpInStr string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 2), - CollationTestCase("ABCDE", ".C.", "UNICODE", 2) - ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(testCase => { - checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") - }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", - parameters = Map( - "functionName" -> "regexp_instr", - "collationName" -> s"${testCase.collation}" - ) - ) - }) - } - test("aggregates count respects collation") { Seq( ("ucs_basic", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), From 616335b7aca34c579cc9b533003cce5394d80bf0 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 13 Mar 2024 09:08:07 +0100 Subject: [PATCH 04/26] Lockdown using StringType --- .../apache/spark/sql/types/StringType.scala | 18 ++++++++++++++++-- .../expressions/stringExpressions.scala | 2 +- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 313f525742ae..06505cd14496 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -40,6 +40,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * equality and hashing). */ def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation + def isLowercaseCollation: Boolean = + CollationFactory.fetchCollation(collationId).isLowercaseCollation /** * Type name that is shown to the customer. @@ -54,8 +56,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa override def hashCode(): Int = collationId.hashCode() - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] - /** * The default size of a value of the StringType is 20 bytes. */ @@ -71,3 +71,17 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa case object StringType extends StringType(0) { def apply(collationId: Int): StringType = new StringType(collationId) } + +case object StringTypeBinaryLowercase extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = StringType + override private[sql] def simpleString: String = "string_bin_lcase" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation || + other.asInstanceOf[StringType].isLowercaseCollation) +} + +case object StringTypeCollated extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = StringType + override private[sql] def simpleString: String = "string_collated" + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e6114ca277ca..6ef3e12f927e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -501,7 +501,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeCollated, StringTypeCollated) override def checkInputDataTypes(): TypeCheckResult = { val checkResult = super.checkInputDataTypes() From 77e606c651997d5d455294c674b0798802811f2e Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:57:54 +0100 Subject: [PATCH 05/26] Lockdown for regexpExpressions --- .../sql/catalyst/util/CollationFactory.java | 2 +- .../apache/spark/sql/types/StringType.scala | 30 +- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- .../CollationTypeConstraints.scala | 54 -- .../expressions/regexpExpressions.scala | 79 +-- .../expressions/stringExpressions.scala | 19 +- .../sql/CollationRegexpExpressionsSuite.scala | 549 ++++++++---------- 7 files changed, 271 insertions(+), 467 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 442873ed0dae..3ecc2fd743cc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -84,7 +84,7 @@ public Collation( this.version = version; this.hashFunction = hashFunction; this.isBinaryCollation = isBinaryCollation; - this.isLowercaseCollation = collationName.equals("UCS_BASIC_LCASE"); + this.isLowercaseCollation = collationName.equals("UTF8_BINARY_LCASE"); if (isBinaryCollation) { this.equalsFunction = UTF8String::equals; diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 3c1443cc3cb2..009656483884 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -40,8 +40,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * equality and hashing). */ def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation - def isLowercaseCollation: Boolean = - CollationFactory.fetchCollation(collationId).isLowercaseCollation + def isLcaseCollation: Boolean = CollationFactory.fetchCollation(collationId).isLowercaseCollation /** * Type name that is shown to the customer. @@ -56,6 +55,9 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa override def hashCode(): Int = collationId.hashCode() + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isDefaultCollation + /** * The default size of a value of the StringType is 20 bytes. */ @@ -65,6 +67,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa } /** + * Use StringType for expressions supporting only binary collation. + * * @since 1.3.0 */ @Stable @@ -72,14 +76,30 @@ case object StringType extends StringType(0) { def apply(collationId: Int): StringType = new StringType(collationId) } -case object StringTypeBinaryLowercase extends AbstractDataType { +/** + * Use StringTypeBinary for expressions supporting only binary collation. + */ +case object StringTypeBinary extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_bin_lcase" + override private[sql] def simpleString: String = "string_binary_lcase" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation +} + +/** + * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. + */ +case object StringTypeBinaryLcase extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = StringType + override private[sql] def simpleString: String = "string_binary_lcase" override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation || - other.asInstanceOf[StringType].isLowercaseCollation) + other.asInstanceOf[StringType].isLcaseCollation) } +/** + * Use StringTypeCollated for expressions supporting all possible collation types. + */ case object StringTypeCollated extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = StringType override private[sql] def simpleString: String = "string_collated" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 56e8843fda53..79f13c8891b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -702,9 +702,10 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { + case (expr: Expression, st2: StringType) if expr.dataType.isInstanceOf[StringType] => expr // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + case (in, expected) => implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index 4fb1a6325d52..c9d8c7d635c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -17,23 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.types.{DataType, StringType} object CollationTypeConstraints { - def checkCollationCompatibilityAndSupport( - checkResult: TypeCheckResult, - collationId: Int, - dataTypes: Seq[DataType], - functionName: String, - collationSupportLevel: CollationSupportLevel.CollationSupportLevel): TypeCheckResult = { - val resultCompatibility = checkCollationCompatibility(checkResult, collationId, dataTypes) - checkCollationSupport(resultCompatibility, collationId, functionName, collationSupportLevel) - } def checkCollationCompatibility( checkResult: TypeCheckResult, @@ -63,47 +52,4 @@ object CollationTypeConstraints { TypeCheckResult.TypeCheckSuccess } - object CollationSupportLevel extends Enumeration { - type CollationSupportLevel = Value - - val SUPPORT_BINARY_ONLY: Value = Value(0) - val SUPPORT_LOWERCASE: Value = Value(1) - val SUPPORT_ALL_COLLATIONS: Value = Value(2) - } - - def checkCollationSupport( - checkResult: TypeCheckResult, - collationId: Int, - functionName: String, - collationSupportLevel: CollationSupportLevel.CollationSupportLevel) - : TypeCheckResult = { - if (checkResult.isFailure) { - return checkResult - } - // Additional check needed for collation support - val collation = CollationFactory.fetchCollation(collationId) - collationSupportLevel match { - case CollationSupportLevel.SUPPORT_BINARY_ONLY => - if (!collation.isBinaryCollation) { - throwUnsupportedCollation(functionName, collation.collationName) - } - case CollationSupportLevel.SUPPORT_LOWERCASE => - if (!collation.isBinaryCollation && !collation.isLowercaseCollation) { - throwUnsupportedCollation(functionName, collation.collationName) - } - case CollationSupportLevel.SUPPORT_ALL_COLLATIONS => // No additional checks needed - case _ => throw new SparkIllegalArgumentException("Invalid collation support level.") - } - TypeCheckResult.TypeCheckSuccess - } - - private def throwUnsupportedCollation(functionName: String, collationName: String): Unit = { - throw new SparkException( - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - messageParameters = Map( - "functionName" -> toSQLId(functionName), - "collationName" -> collationName), - cause = null - ) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 8e093f2ddb83..e4ee405c6249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -46,13 +46,6 @@ abstract class StringRegexExpression extends BinaryExpression override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, - children.map(_.dataType)) - } - // try cache foldable pattern private lazy val cache: Pattern = right match { case p: Expression if p.foldable => @@ -137,11 +130,6 @@ abstract class StringRegexExpression extends BinaryExpression case class Like(left: Expression, right: Expression, escapeChar: Char) extends StringRegexExpression { - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, - prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - def this(left: Expression, right: Expression) = this(left, right, '\\') override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) @@ -272,14 +260,6 @@ case class ILike( override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibilityAndSupport( - super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { copy(left = newLeft, right = newRight) @@ -481,11 +461,6 @@ case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends Like // scalastyle:on line.contains.tab line.size.limit case class RLike(left: Expression, right: Expression) extends StringRegexExpression { - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, - prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"RLIKE($left, $right)" @@ -571,14 +546,6 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(StringType, containsNull = false) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibilityAndSupport( - super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override def first: Expression = str override def second: Expression = regex override def third: Expression = limit @@ -648,9 +615,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio def this(subject: Expression, regexp: Expression, rep: Expression) = this(subject, regexp, rep, Literal(1)) - final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId - override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + if (!pos.foldable) { return DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -675,10 +645,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ) ) } - - CollationTypeConstraints.checkCollationCompatibilityAndSupport( - super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) } // last regex in string, we will update the pattern iff regexp value changed. @@ -809,12 +775,6 @@ abstract class RegExpExtractBase override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, - children.map(_.dataType)) - } override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -893,11 +853,6 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, - prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -998,10 +953,6 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres override def dataType: DataType = ArrayType(StringType) - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, - prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } override def prettyName: String = "regexp_extract_all" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1076,14 +1027,6 @@ case class RegExpCount(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibilityAndSupport( - super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = copy(left = newChildren(0), right = newChildren(1)) @@ -1123,14 +1066,6 @@ case class RegExpSubStr(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId - - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibilityAndSupport( - super.checkInputDataTypes(), collationId, children.map(_.dataType), prettyName, - CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } - override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = copy(left = newChildren(0), right = newChildren(1)) @@ -1184,10 +1119,6 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) override def dataType: DataType = IntegerType - override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationSupport(super.checkInputDataTypes(), collationId, - prettyName, CollationTypeConstraints.CollationSupportLevel.SUPPORT_BINARY_ONLY) - } override def prettyName: String = "regexp_instr" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8582b8da6d9b..11205b0832be 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -504,23 +504,8 @@ abstract class StringPredicate extends BinaryExpression override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeCollated, StringTypeCollated) override def checkInputDataTypes(): TypeCheckResult = { - val checkResult = super.checkInputDataTypes() - if (checkResult.isFailure) { - return checkResult - } - // Additional check needed for collation compatibility - val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId - if (collationId != rightCollationId) { - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, - "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName - ) - ) - } else { - TypeCheckResult.TypeCheckSuccess - } + CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, + children.map(_.dataType)) } protected override def nullSafeEval(input1: Any, input2: Any): Any = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 18a09214228a..10127ef38c33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -19,204 +19,48 @@ package org.apache.spark.sql import scala.collection.immutable.Seq -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.test.SharedSparkSession class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession { - test("Check collation compatibility for regexp functions") { - // like - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT 'ABC' like collate('%b%', 'UNICODE_CI')") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"ABC LIKE collate(%b%)\"" - ), - context = ExpectedContext(fragment = - s"like collate('%b%', 'UNICODE_CI')", - start = 13, stop = 45) - ) - // ilike - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT 'ABC' ilike collate('%b%', 'UNICODE_CI')") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"ilike(ABC, collate(%b%))\"" - ), - context = ExpectedContext(fragment = - s"ilike collate('%b%', 'UNICODE_CI')", - start = 13, stop = 46) - ) - // rlike - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT 'ABC' rlike collate('%b%', 'UNICODE_CI')") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"RLIKE(ABC, collate(%b%))\"" - ), - context = ExpectedContext(fragment = - s"rlike collate('%b%', 'UNICODE_CI')", - start = 13, stop = 46) - ) - // split - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT split('ABC', collate('[B]', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"split(ABC, collate([B]), -1)\"" - ), - context = ExpectedContext(fragment = - s"split('ABC', collate('[B]', 'UNICODE_CI'))", - start = 7, stop = 48) - ) - // regexp_replace - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_replace('ABCDE', collate('.c.', 'UNICODE_CI'), 'F')") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_replace(ABCDE, collate(.c.), F, 1)\"" - ), - context = ExpectedContext(fragment = - s"regexp_replace('ABCDE', collate('.c.', 'UNICODE_CI'), 'F')", - start = 7, stop = 64) - ) - // regexp_extract - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_extract('ABCDE', collate('.c.', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_extract(ABCDE, collate(.c.), 1)\"" - ), - context = ExpectedContext(fragment = - s"regexp_extract('ABCDE', collate('.c.', 'UNICODE_CI'))", - start = 7, stop = 59) - ) - // regexp_extract_all - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_extract_all('ABCDE', collate('.c.', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_extract_all(ABCDE, collate(.c.), 1)\"" - ), - context = ExpectedContext(fragment = - s"regexp_extract_all('ABCDE', collate('.c.', 'UNICODE_CI'))", - start = 7, stop = 63) - ) - // regexp_count - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_count('ABCDE', collate('.c.', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_count(ABCDE, collate(.c.))\"" - ), - context = ExpectedContext(fragment = - s"regexp_count('ABCDE', collate('.c.', 'UNICODE_CI'))", - start = 7, stop = 57) - ) - // regexp_substr - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_substr('ABCDE', collate('.c.', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_substr(ABCDE, collate(.c.))\"" - ), - context = ExpectedContext(fragment = - s"regexp_substr('ABCDE', collate('.c.', 'UNICODE_CI'))", - start = 7, stop = 58) - ) - // regexp_instr - checkError( - exception = intercept[ExtendedAnalysisException] { - spark.sql(s"SELECT regexp_instr('ABCDE', collate('.c.', 'UNICODE_CI'))") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "collationNameLeft" -> s"UCS_BASIC", - "collationNameRight" -> s"UNICODE_CI", - "sqlExpr" -> "\"regexp_instr(ABCDE, collate(.c.), 0)\"" - ), - context = ExpectedContext(fragment = - s"regexp_instr('ABCDE', collate('.c.', 'UNICODE_CI'))", - start = 7, stop = 57) - ) - } - - case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R) - case class CollationTestFail[R](left: String, right: String, collation: String) + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support Like string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABC", "%B%", "UCS_BASIC", true), - CollationTestCase("ABC", "%B%", "UNICODE", true) + CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), CollationTestCase("ABC", "%b%", "UNICODE_CI", false) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') like " + - s"collate('${testCase.right}', '${testCase.collation}')") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`like`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"like collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 48 + 2 * ct.collation.length ) ) }) @@ -225,29 +69,37 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support ILike string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC", true), - CollationTestCase("ABC", "%b%", "UNICODE", true) + CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "%b%", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE", true), CollationTestCase("ABC", "%b%", "UNICODE_CI", false) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') ilike " + - s"collate('${testCase.right}', '${testCase.collation}')") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`ilike`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"ilike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length ) ) }) @@ -256,29 +108,37 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RLike string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABC", ".B.", "UCS_BASIC", true), - CollationTestCase("ABC", ".B.", "UNICODE", true) + CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + - s"collate('${testCase.right}', '${testCase.collation}')"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", ".b.", "UCS_BASIC_LCASE", false), + CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", ".B.", "UNICODE", true), CollationTestCase("ABC", ".b.", "UNICODE_CI", false) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT collate('${testCase.left}', '${testCase.collation}') rlike " + - s"collate('${testCase.right}', '${testCase.collation}')") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`rlike`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"rlike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length ) ) }) @@ -287,29 +147,38 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support StringSplit string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABC", "[B]", "UCS_BASIC", 2), - CollationTestCase("ABC", "[B]", "UNICODE", 2) + CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')))"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABC", "b", "UCS_BASIC_LCASE", 0), - CollationTestCase("ABC", "b", "UNICODE_CI", 0) + CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABC", "[B]", "UNICODE", 2), + CollationTestCase("ABC", "[b]", "UNICODE_CI", 0) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT size(split(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`split`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"split(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 12, + stop = 55 + 2 * ct.collation.length ) ) }) @@ -318,34 +187,43 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpReplace string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "AFFFE"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE") + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") ) - checks.foreach(testCase => { + checks.foreach(ct => { checkAnswer( - sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')" + - s",collate('FFF', '${testCase.collation}'))"), - Row(testCase.expectedResult) + sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))"), + Row(ct.expectedResult) ) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_replace(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}')" + - s",collate('FFF', '${testCase.collation}'))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_replace`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," + + s" '${ct.collation}'),collate('FFF', '${ct.collation}'))", + start = 7, + stop = 80 + 3 * ct.collation.length ) ) }) @@ -354,33 +232,41 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpExtract string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") ) - checks.foreach(testCase => { + checks.foreach(ct => { checkAnswer( - sql(s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0)"), - Row(testCase.expectedResult) + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0)"), + Row(ct.expectedResult) ) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql( - s"SELECT regexp_extract(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0)") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0)") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_extract`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 7, + stop = 63 + 2 * ct.collation.length ) ) }) @@ -389,33 +275,41 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpExtractAll string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) ) - checks.foreach(testCase => { + checks.foreach(ct => { checkAnswer( - sql(s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0))"), - Row(testCase.expectedResult) + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0))"), + Row(ct.expectedResult) ) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql( - s"SELECT size(regexp_extract_all(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'),0))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + + s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_extract_all`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 12, + stop = 72 + 2 * ct.collation.length ) ) }) @@ -424,29 +318,38 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpCount string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 1), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1) + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_count(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_count`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length ) ) }) @@ -455,29 +358,38 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpSubStr string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", "BCD"), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD") + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", ""), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_substr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_substr`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 60 + 2 * ct.collation.length ) ) }) @@ -486,29 +398,38 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession test("Support RegExpInStr string expression with Collation") { // Supported collations val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UCS_BASIC", 2), - CollationTestCase("ABCDE", ".C.", "UNICODE", 2) + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) ) - checks.foreach(testCase => { - checkAnswer(sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) }) // Unsupported collations val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UCS_BASIC_LCASE", 0), + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 2), CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) ) - fails.foreach(testCase => { + fails.foreach(ct => { checkError( - exception = intercept[SparkException] { - sql(s"SELECT regexp_instr(collate('${testCase.left}', '${testCase.collation}')" + - s",collate('${testCase.right}', '${testCase.collation}'))") + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") }, - errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", - sqlState = "0A000", + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", parameters = Map( - "functionName" -> "`regexp_instr`", - "collationName" -> s"${testCase.collation}" + "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length ) ) }) From ee72ed372244518318ad23e4b5291b59f4f6e9ac Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 13 Mar 2024 13:21:43 +0100 Subject: [PATCH 06/26] Renaming and small fixes --- .../scala/org/apache/spark/sql/types/StringType.scala | 8 ++++---- .../sql/catalyst/expressions/collationExpressions.scala | 2 +- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 009656483884..47db597aa1be 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -81,7 +81,7 @@ case object StringType extends StringType(0) { */ case object StringTypeBinary extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_binary_lcase" + override private[sql] def simpleString: String = "string_binary" override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation } @@ -98,10 +98,10 @@ case object StringTypeBinaryLcase extends AbstractDataType { } /** - * Use StringTypeCollated for expressions supporting all possible collation types. + * Use StringTypeAnyCollation for expressions supporting all possible collation types. */ -case object StringTypeCollated extends AbstractDataType { +case object StringTypeAnyCollation extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_collated" + override private[sql] def simpleString: String = "string_any_collation" override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index b0f77bad4483..8ef0280b728e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -82,7 +82,7 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 11205b0832be..c3ed179d0571 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -501,7 +501,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeCollated, StringTypeCollated) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, From 8bc10f54fb47e3a538725d6937573a0e0e020971 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:02:15 +0100 Subject: [PATCH 07/26] Code style fixes --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 6 ------ .../spark/sql/catalyst/expressions/stringExpressions.scala | 3 ++- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index e4ee405c6249..b33de303b5d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -545,7 +545,6 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(StringType, containsNull = false) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - override def first: Expression = str override def second: Expression = regex override def third: Expression = limit @@ -620,7 +619,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio if (defaultCheck.isFailure) { return defaultCheck } - if (!pos.foldable) { return DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -774,7 +772,6 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -852,7 +849,6 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } override def dataType: DataType = StringType - override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -952,7 +948,6 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres } override def dataType: DataType = ArrayType(StringType) - override def prettyName: String = "regexp_extract_all" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1118,7 +1113,6 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) } override def dataType: DataType = IntegerType - override def prettyName: String = "regexp_instr" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c3ed179d0571..811d96d24836 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -501,7 +501,8 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, From 8ee0f585903af85d3dd63a0b80d58be13081d6b9 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:02:21 +0100 Subject: [PATCH 08/26] Move StringType lockdown condition to case object --- .../main/scala/org/apache/spark/sql/types/StringType.scala | 5 ++--- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 47db597aa1be..606f5be54a16 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -55,9 +55,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa override def hashCode(): Int = collationId.hashCode() - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isDefaultCollation - /** * The default size of a value of the StringType is 20 bytes. */ @@ -74,6 +71,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa @Stable case object StringType extends StringType(0) { def apply(collationId: Int): StringType = new StringType(collationId) + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isDefaultCollation } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 79f13c8891b6..ec91e3ccac38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -703,7 +703,7 @@ abstract class TypeCoercionBase { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - case (expr: Expression, st2: StringType) if expr.dataType.isInstanceOf[StringType] => expr + case (expr: Expression, StringType) if expr.dataType.isInstanceOf[StringType] => expr // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } From d3a9e70406c05c0a7e853a0b0970df58edfaf8b2 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 14 Mar 2024 08:45:43 +0100 Subject: [PATCH 09/26] Fix checkInputDataTypes --- .../catalyst/expressions/CollationTypeConstraints.scala | 8 +------- .../sql/catalyst/expressions/stringExpressions.scala | 7 +++++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index c9d8c7d635c6..99a5a1eee60c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -24,13 +24,7 @@ import org.apache.spark.sql.types.{DataType, StringType} object CollationTypeConstraints { - def checkCollationCompatibility( - checkResult: TypeCheckResult, - collationId: Int, - dataTypes: Seq[DataType]): TypeCheckResult = { - if (checkResult.isFailure) { - return checkResult - } + def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = { val collationName = CollationFactory.fetchCollation(collationId).collationName // Additional check needed for collation compatibility for (dataType <- dataTypes) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 811d96d24836..8efc5ad5f7ba 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -505,8 +505,11 @@ abstract class StringPredicate extends BinaryExpression Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { - CollationTypeConstraints.checkCollationCompatibility(super.checkInputDataTypes(), collationId, - children.map(_.dataType)) + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) } protected override def nullSafeEval(input1: Any, input2: Any): Any = From c11c458463a04e516291b2d581c73ae033e2fac4 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:02:51 +0100 Subject: [PATCH 10/26] Implicit cast for new StringTypes --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec91e3ccac38..6619cefa90b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -704,6 +704,9 @@ abstract class TypeCoercionBase { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (expr: Expression, StringType) if expr.dataType.isInstanceOf[StringType] => expr + case (expr: Expression, StringTypeBinary | StringTypeBinaryLcase | StringTypeAnyCollation) + if !expr.dataType.isInstanceOf[StringType] => + implicitCast(expr, StringType).getOrElse(expr) // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } @@ -886,8 +889,8 @@ object TypeCoercion extends TypeCoercionBase { /** Promotes all the way to StringType. */ private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case (st: StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(st) + case (t1: AtomicType, st: StringType) if t1 != BinaryType && t1 != BooleanType => Some(st) case _ => None } @@ -994,9 +997,9 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, target: NumericType) => target case (StringType, datetime: DatetimeType) => datetime case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType - case (StringType, BinaryType) => BinaryType + case (_: StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => StringType + case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. From 4778951a02a6b84771356e6e4d0e4a985e6cf7e8 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:06:57 +0100 Subject: [PATCH 11/26] Implicit cast for new StringTypes --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6619cefa90b5..b27ef737946b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -889,8 +889,8 @@ object TypeCoercion extends TypeCoercionBase { /** Promotes all the way to StringType. */ private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { - case (st: StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(st) - case (t1: AtomicType, st: StringType) if t1 != BinaryType && t1 != BooleanType => Some(st) + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) case _ => None } @@ -997,9 +997,9 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, target: NumericType) => target case (StringType, datetime: DatetimeType) => datetime case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType - case (_: StringType, BinaryType) => BinaryType + case (StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st + case (any: AtomicType, StringType) if any != StringType => StringType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. From 1582f242ed7150bf83020ac4924af4e9f1b01500 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:50:22 +0100 Subject: [PATCH 12/26] Separate suite for stringExpressions --- .../sql/CollationStringExpressionsSuite.scala | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala new file mode 100644 index 000000000000..b20aceb11b90 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.test.SharedSparkSession + +class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { + + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + + test("Support ConcatWs string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"), + Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + val expr = s"concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))" + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT $expr") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate( )\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"$expr", + start = 7, + stop = 73 + 3 * ct.collation.length + ) + ) + }) + } + + // TODO: Add more tests for other string expressions + +} From 8d75846af7b086d6f71477b751aed4a8f3dcac1c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 14 Mar 2024 15:04:35 +0100 Subject: [PATCH 13/26] Collation test fix --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 72e72a53c4f6..635ec600f1ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -122,7 +122,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "paramIndex" -> "first", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\""), + "requiredType" -> "\"STRING_ANY_COLLATION\""), context = ExpectedContext( fragment = s"collate(1, 'UTF8_BINARY')", start = 7, stop = 31)) } From 63caef65c85884931c26e49dd50f53907a7bd7ce Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 09:33:16 +0100 Subject: [PATCH 14/26] Reformat --- .../apache/spark/sql/types/StringType.scala | 2 -- .../sql/catalyst/analysis/TypeCoercion.scala | 18 ++++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 606f5be54a16..e23384a8a6be 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -71,8 +71,6 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa @Stable case object StringType extends StringType(0) { def apply(collationId: Int): StringType = new StringType(collationId) - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isDefaultCollation } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b27ef737946b..d101f30f41ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -703,10 +703,6 @@ abstract class TypeCoercionBase { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { - case (expr: Expression, StringType) if expr.dataType.isInstanceOf[StringType] => expr - case (expr: Expression, StringTypeBinary | StringTypeBinaryLcase | StringTypeAnyCollation) - if !expr.dataType.isInstanceOf[StringType] => - implicitCast(expr, StringType).getOrElse(expr) // If we cannot do the implicit cast, just use the original input. case (in, expected) => implicitCast(in, expected).getOrElse(in) } @@ -960,9 +956,19 @@ object TypeCoercion extends TypeCoercionBase { }) } + @tailrec override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - implicitCast(e.dataType, expectedType).map { dt => - if (dt == e.dataType) e else Cast(e, dt) + e match { + // Handle special cases for collation types + case expr: Expression if expr.dataType.isInstanceOf[StringType] && + expectedType == StringType => Some(expr) + case expr: Expression if !expr.dataType.isInstanceOf[StringType] && + (expectedType == StringTypeBinary || expectedType == StringTypeBinaryLcase || + expectedType == StringTypeAnyCollation) => implicitCast(expr, StringType) + // Otherwise attempt regular implicit cast + case _ => implicitCast(e.dataType, expectedType).map { dt => + if (dt == e.dataType) e else Cast(e, dt) + } } } From 38c39067e5d079afe335a3ceb251c5dce752998d Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 09:40:35 +0100 Subject: [PATCH 15/26] Remove extra changes --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index d101f30f41ee..8df7874ef50b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -702,9 +702,9 @@ abstract class TypeCoercionBase { }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => - val children: Seq[Expression] = e.children.zip(e.inputTypes).map { + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - case (in, expected) => implicitCast(in, expected).getOrElse(in) + implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) From 1d60cefdac09a5e646eb7e9acb02054085bfb7cc Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 09:45:21 +0100 Subject: [PATCH 16/26] Rewrite collation mismatch check --- .../CollationTypeConstraints.scala | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index 99a5a1eee60c..a286e21618b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -27,23 +27,17 @@ object CollationTypeConstraints { def checkCollationCompatibility(collationId: Int, dataTypes: Seq[DataType]): TypeCheckResult = { val collationName = CollationFactory.fetchCollation(collationId).collationName // Additional check needed for collation compatibility - for (dataType <- dataTypes) { - dataType match { - case stringType: StringType => - if (stringType.collationId != collationId) { - val collation = CollationFactory.fetchCollation(stringType.collationId) - return DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> collationName, - "collationNameRight" -> collation.collationName - ) - ) - } - case _ => - } - } - TypeCheckResult.TypeCheckSuccess + dataTypes.collectFirst { + case stringType: StringType if stringType.collationId != collationId => + val collation = CollationFactory.fetchCollation(stringType.collationId) + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> collationName, + "collationNameRight" -> collation.collationName + ) + ) + } getOrElse TypeCheckResult.TypeCheckSuccess } } From 2f5aba9424ca2ef68f09e70f0d6176e4cd06f8cd Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:23:11 +0100 Subject: [PATCH 17/26] Improve lockdown handling --- .../sql/catalyst/util/CollationFactory.java | 2 -- .../apache/spark/sql/types/StringType.scala | 32 +---------------- .../sql/catalyst/analysis/TypeCoercion.scala | 16 +++------ .../CollationTypeConstraints.scala | 36 ++++++++++++++++++- 4 files changed, 40 insertions(+), 46 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 3ecc2fd743cc..2940900b974a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -69,7 +69,6 @@ public static class Collation { * byte for byte equal. All accent or case-insensitive collations are considered non-binary. */ public final boolean isBinaryCollation; - public final boolean isLowercaseCollation; public Collation( String collationName, @@ -84,7 +83,6 @@ public Collation( this.version = version; this.hashFunction = hashFunction; this.isBinaryCollation = isBinaryCollation; - this.isLowercaseCollation = collationName.equals("UTF8_BINARY_LCASE"); if (isBinaryCollation) { this.equalsFunction = UTF8String::equals; diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index e23384a8a6be..2b88f9a01a73 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -40,7 +40,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa * equality and hashing). */ def isBinaryCollation: Boolean = CollationFactory.fetchCollation(collationId).isBinaryCollation - def isLcaseCollation: Boolean = CollationFactory.fetchCollation(collationId).isLowercaseCollation + def isLowercaseCollation: Boolean = collationId == CollationFactory.LOWERCASE_COLLATION_ID /** * Type name that is shown to the customer. @@ -72,33 +72,3 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa case object StringType extends StringType(0) { def apply(collationId: Int): StringType = new StringType(collationId) } - -/** - * Use StringTypeBinary for expressions supporting only binary collation. - */ -case object StringTypeBinary extends AbstractDataType { - override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_binary" - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation -} - -/** - * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. - */ -case object StringTypeBinaryLcase extends AbstractDataType { - override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_binary_lcase" - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation || - other.asInstanceOf[StringType].isLcaseCollation) -} - -/** - * Use StringTypeAnyCollation for expressions supporting all possible collation types. - */ -case object StringTypeAnyCollation extends AbstractDataType { - override private[sql] def defaultConcreteType: DataType = StringType - override private[sql] def simpleString: String = "string_any_collation" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8df7874ef50b..2f0c57e4b53f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -956,19 +956,9 @@ object TypeCoercion extends TypeCoercionBase { }) } - @tailrec override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - e match { - // Handle special cases for collation types - case expr: Expression if expr.dataType.isInstanceOf[StringType] && - expectedType == StringType => Some(expr) - case expr: Expression if !expr.dataType.isInstanceOf[StringType] && - (expectedType == StringTypeBinary || expectedType == StringTypeBinaryLcase || - expectedType == StringTypeAnyCollation) => implicitCast(expr, StringType) - // Otherwise attempt regular implicit cast - case _ => implicitCast(e.dataType, expectedType).map { dt => - if (dt == e.dataType) e else Cast(e, dt) - } + implicitCast(e.dataType, expectedType).map { dt => + if (dt == e.dataType) e else Cast(e, dt) } } @@ -1004,8 +994,10 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, datetime: DatetimeType) => datetime case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (StringType, BinaryType) => BinaryType + case (st: StringType, StringType) => st // Cast any atomic type to string. case (any: AtomicType, StringType) if any != StringType => StringType + case (any: AtomicType, _: StringTypeCollated) if any != StringType => StringType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala index a286e21618b0..cd909a45c1ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationTypeConstraints.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} object CollationTypeConstraints { @@ -41,3 +41,37 @@ object CollationTypeConstraints { } } + +/** + * StringTypeCollated is an abstract class for StringType with collation support. + */ +abstract class StringTypeCollated extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = StringType +} + +/** + * Use StringTypeBinary for expressions supporting only binary collation. + */ +case object StringTypeBinary extends StringTypeCollated { + override private[sql] def simpleString: String = "string_binary" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isBinaryCollation +} + +/** + * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. + */ +case object StringTypeBinaryLcase extends StringTypeCollated { + override private[sql] def simpleString: String = "string_binary_lcase" + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].isBinaryCollation || + other.asInstanceOf[StringType].isLowercaseCollation) +} + +/** + * Use StringTypeAnyCollation for expressions supporting all possible collation types. + */ +case object StringTypeAnyCollation extends StringTypeCollated { + override private[sql] def simpleString: String = "string_any_collation" + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] +} From b9b185d392c732f6522d145b97f0e3b7365042c1 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:41:03 +0100 Subject: [PATCH 18/26] Add ANSI type coercion --- .../apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 8857f0b5a25e..5aecc46ef4f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -205,6 +205,10 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) + // If the StringType has any collation other than UTF8_BINARY, it shouldn't be implicitly + // cast to StringType with collation 0. + case (st: StringType, StringType) => Some(st) + case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) From 3fed6cbc4bf588dc354e74ae54edf9746ae1585c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:49:44 +0100 Subject: [PATCH 19/26] Better ANSI type coercion explanation --- .../apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 5aecc46ef4f6..a6a95068f059 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -205,8 +205,8 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) - // If the StringType has any collation other than UTF8_BINARY, it shouldn't be implicitly - // cast to StringType with collation 0. + // If a function expects StringType, no StringType instance should be implicitly cast to + // StringType with default collation. case (st: StringType, StringType) => Some(st) case (DateType, AnyTimestampType) => From 638c6d467584441fd6cd198b5c0ca54224027e5a Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:02:05 +0100 Subject: [PATCH 20/26] Fix type coercion --- .../spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 5 +++-- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index a6a95068f059..ce7d08e164cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -205,9 +205,10 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) - // If a function expects StringType, no StringType instance should be implicitly cast to - // StringType with default collation. + // If a function expects a StringType, no StringType instance should be implicitly cast to + // StringType with an incompatible collation (aka. lockdown unsupported collations). case (st: StringType, StringType) => Some(st) + case (st: StringType, _: StringTypeCollated) => Some(st) case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2f0c57e4b53f..afad095771ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -995,6 +995,7 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (StringType, BinaryType) => BinaryType case (st: StringType, StringType) => st + case (st: StringType, _: StringTypeCollated) => st // Cast any atomic type to string. case (any: AtomicType, StringType) if any != StringType => StringType case (any: AtomicType, _: StringTypeCollated) if any != StringType => StringType From af7611ff6cc4f23430987717daf7f1deeda22c01 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Sun, 17 Mar 2024 21:06:51 +0100 Subject: [PATCH 21/26] Fix ANSI type coercion --- .../spark/sql/catalyst/analysis/AnsiTypeCoercion.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index ce7d08e164cd..5940438728a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -206,9 +206,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { Some(AnyTimestampType.defaultConcreteType) // If a function expects a StringType, no StringType instance should be implicitly cast to - // StringType with an incompatible collation (aka. lockdown unsupported collations). - case (st: StringType, StringType) => Some(st) - case (st: StringType, _: StringTypeCollated) => Some(st) + // StringType with a collation that's not accepted (aka. lockdown unsupported collations). + case (StringType, StringType) => None + case (StringType, _: StringTypeCollated) => None case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) @@ -220,6 +220,10 @@ object AnsiTypeCoercion extends TypeCoercionBase { None } + // "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but + // ANSIStoreAssign is generally expected to return "true" for (AtomicType, StringType) + case (_: AtomicType, _: StringTypeCollated) => Some(StringType) + // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => From 0f4a6b1aea80e3189cf93ca82fde7b75775b007c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 11:17:07 +0100 Subject: [PATCH 22/26] Fix failing tests --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 3 ++- .../sql/catalyst/expressions/stringExpressions.scala | 6 +++--- .../test/scala/org/apache/spark/sql/CollationSuite.scala | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index afad095771ed..57f15e00ffb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -998,7 +998,8 @@ object TypeCoercion extends TypeCoercionBase { case (st: StringType, _: StringTypeCollated) => st // Cast any atomic type to string. case (any: AtomicType, StringType) if any != StringType => StringType - case (any: AtomicType, _: StringTypeCollated) if any != StringType => StringType + case (any: AtomicType, st: StringTypeCollated) + if any != st.defaultConcreteType => st.defaultConcreteType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8efc5ad5f7ba..742db0ed5a47 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -427,8 +427,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -1965,7 +1965,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 543b0067d0a1..3fea89880910 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -633,7 +633,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s""" |CREATE TABLE testcat.test_table( | c1 STRING COLLATE UNICODE, - | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (c1 || 'a' COLLATE UNICODE) + | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (LOWER(c1)) |) |USING $v2Source |""".stripMargin) @@ -641,7 +641,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", - "expressionStr" -> "c1 || 'a' COLLATE UNICODE", + "expressionStr" -> "LOWER(c1)", "reason" -> "generation expression cannot contain non-default collated string type")) checkError( @@ -650,7 +650,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s""" |CREATE TABLE testcat.test_table( | struct1 STRUCT, - | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (SUBSTRING(struct1.a, 0, 1)) + | c2 STRING COLLATE UNICODE GENERATED ALWAYS AS (UCASE(struct1.a)) |) |USING $v2Source |""".stripMargin) @@ -658,7 +658,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", - "expressionStr" -> "SUBSTRING(struct1.a, 0, 1)", + "expressionStr" -> "UCASE(struct1.a)", "reason" -> "generation expression cannot contain non-default collated string type")) } } From 9b24fba9340e8984060b28c457fd37191f856d89 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 19 Mar 2024 21:07:20 +0100 Subject: [PATCH 23/26] Incorporate requested changes --- .../catalyst/analysis/AnsiTypeCoercion.scala | 18 +- .../sql/catalyst/analysis/TypeCoercion.scala | 6 +- .../CollationRegexpExpressionsANSISuite.scala | 442 ++++++++++++++++++ .../CollationStringExpressionsANSISuite.scala | 77 +++ 4 files changed, 533 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 5940438728a7..ff08a1814201 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -186,6 +186,11 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (NullType, target) if !target.isInstanceOf[TypeCollection] => Some(target.defaultConcreteType) + // If a function expects a StringType, no StringType instance should be implicitly cast to + // StringType with a collation that's not accepted (aka. lockdown unsupported collations). + case (_: StringType, StringType) => None + case (_: StringType, _: StringTypeCollated) => None + // This type coercion system will allow implicit converting String type as other // primitive types, in case of breaking too many existing Spark SQL queries. case (StringType, a: AtomicType) => @@ -205,11 +210,6 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) - // If a function expects a StringType, no StringType instance should be implicitly cast to - // StringType with a collation that's not accepted (aka. lockdown unsupported collations). - case (StringType, StringType) => None - case (StringType, _: StringTypeCollated) => None - case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) @@ -222,7 +222,13 @@ object AnsiTypeCoercion extends TypeCoercionBase { // "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but // ANSIStoreAssign is generally expected to return "true" for (AtomicType, StringType) - case (_: AtomicType, _: StringTypeCollated) => Some(StringType) + case (_, st: StringTypeCollated) => + if (Cast.canANSIStoreAssign(inType, st.defaultConcreteType)) { + Some(st.defaultConcreteType) + } + else { + None + } // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 57f15e00ffb4..ecc54976f2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -994,12 +994,10 @@ object TypeCoercion extends TypeCoercionBase { case (StringType, datetime: DatetimeType) => datetime case (StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType case (StringType, BinaryType) => BinaryType - case (st: StringType, StringType) => st - case (st: StringType, _: StringTypeCollated) => st // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => StringType + case (any: AtomicType, StringType) if !any.isInstanceOf[StringType] => StringType case (any: AtomicType, st: StringTypeCollated) - if any != st.defaultConcreteType => st.defaultConcreteType + if !any.isInstanceOf[StringType] => st.defaultConcreteType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala new file mode 100644 index 000000000000..ad2420533144 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class CollationRegexpExpressionsANSISuite extends QueryTest with SharedSparkSession { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + + test("Support Like string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"like collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 48 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support ILike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%b%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"ilike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RLike string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", ".B.", "UNICODE", true), + CollationTestCase("ABC", ".b.", "UNICODE_CI", false) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + + s"collate('${ct.s2}', '${ct.collation}')") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"rlike collate('${ct.s2}', '${ct.collation}')", + start = 26 + ct.collation.length, + stop = 49 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support StringSplit string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABC", "[B]", "UNICODE", 2), + CollationTestCase("ABC", "[b]", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"split(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 12, + stop = 55 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpReplace string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}')" + + s",collate('FFF', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," + + s" '${ct.collation}'),collate('FFF', '${ct.collation}'))", + start = 7, + stop = 80 + 3 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpExtract string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0)"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 7, + stop = 63 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpExtractAll string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + checks.foreach(ct => { + checkAnswer( + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'),0))"), + Row(ct.expectedResult) + ) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + + s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'),0)", + start = 12, + stop = 72 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpCount string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 1), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpSubStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), + CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 60 + 2 * ct.collation.length + ) + ) + }) + } + + test("Support RegExpInStr string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), + CollationTestCase("ABCDE", ".C.", "UNICODE", 2), + CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) + ) + fails.foreach(ct => { + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + + s",collate('${ct.s2}', '${ct.collation}'))") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate(${ct.s1})\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," + + s"collate('${ct.s2}', '${ct.collation}'))", + start = 7, + stop = 59 + 2 * ct.collation.length + ) + ) + }) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala new file mode 100644 index 000000000000..73568cfcb862 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class CollationStringExpressionsANSISuite extends QueryTest with SharedSparkSession { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) + case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) + + test("Support ConcatWs string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"), + Row(ct.expectedResult)) + }) + // Unsupported collations + val fails = Seq( + CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), + CollationTestCase("ABC", "%B%", "UNICODE", true), + CollationTestCase("ABC", "%b%", "UNICODE_CI", false) + ) + fails.foreach(ct => { + val expr = s"concat_ws(collate(' ', '${ct.collation}'), " + + s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))" + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT $expr") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"", + "paramIndex" -> "first", + "inputSql" -> s"\"collate( )\"", + "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", + "requiredType" -> "\"STRING\"" + ), + context = ExpectedContext( + fragment = s"$expr", + start = 7, + stop = 73 + 3 * ct.collation.length + ) + ) + }) + } + + // TODO: Add more tests for other string expressions + +} From 99b36b08d8f6ba821a6ae22ed34148a0180fe5d6 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 20 Mar 2024 11:20:59 +0100 Subject: [PATCH 24/26] Refactor code and fix comments --- .../catalyst/analysis/AnsiTypeCoercion.scala | 2 +- .../CollationRegexpExpressionsANSISuite.scala | 442 ------------------ .../sql/CollationRegexpExpressionsSuite.scala | 6 + .../CollationStringExpressionsANSISuite.scala | 77 --- .../sql/CollationStringExpressionsSuite.scala | 9 +- 5 files changed, 14 insertions(+), 522 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index ff08a1814201..c70d6696ad06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -221,7 +221,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { } // "canANSIStoreAssign" doesn't account for targets extending StringTypeCollated, but - // ANSIStoreAssign is generally expected to return "true" for (AtomicType, StringType) + // ANSIStoreAssign is generally expected to work with StringTypes case (_, st: StringTypeCollated) => if (Cast.canANSIStoreAssign(inType, st.defaultConcreteType)) { Some(st.defaultConcreteType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala deleted file mode 100644 index ad2420533144..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsANSISuite.scala +++ /dev/null @@ -1,442 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.immutable.Seq - -import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession - -class CollationRegexpExpressionsANSISuite extends QueryTest with SharedSparkSession { - - override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.ANSI_ENABLED, true) - case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) - - test("Support Like string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%B%", "UTF8_BINARY", true) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%B%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') like " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"collate(${ct.s1}) LIKE collate(${ct.s2})\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"like collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 48 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support ILike string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY", true) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%b%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') ilike " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"ilike(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"ilike collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 49 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RLike string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", ".B.", "UTF8_BINARY", true) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + - s"collate('${ct.s2}', '${ct.collation}')"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", ".b.", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", ".B.", "UNICODE", true), - CollationTestCase("ABC", ".b.", "UNICODE_CI", false) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT collate('${ct.s1}', '${ct.collation}') rlike " + - s"collate('${ct.s2}', '${ct.collation}')") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"RLIKE(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"rlike collate('${ct.s2}', '${ct.collation}')", - start = 26 + ct.collation.length, - stop = 49 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support StringSplit string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABC", "[B]", "UTF8_BINARY", 2) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')))"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABC", "[B]", "UNICODE", 2), - CollationTestCase("ABC", "[b]", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT size(split(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"split(collate(${ct.s1}), collate(${ct.s2}), -1)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"split(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 12, - stop = 55 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpReplace string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE") - ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')" + - s",collate('FFF', '${ct.collation}'))"), - Row(ct.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "AFFFE"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_replace(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}')" + - s",collate('FFF', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_replace(collate(${ct.s1}), collate(${ct.s2}), collate(FFF), 1)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_replace(collate('${ct.s1}', '${ct.collation}'),collate('${ct.s2}'," + - s" '${ct.collation}'),collate('FFF', '${ct.collation}'))", - start = 7, - stop = 80 + 3 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpExtract string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") - ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0)"), - Row(ct.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_extract(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0)") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_extract(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_extract(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'),0)", - start = 7, - stop = 63 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpExtractAll string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) - ) - checks.foreach(ct => { - checkAnswer( - sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'),0))"), - Row(ct.expectedResult) - ) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT size(regexp_extract_all(collate('${ct.s1}', " + - s"'${ct.collation}'),collate('${ct.s2}', '${ct.collation}'),0))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_extract_all(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_extract_all(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'),0)", - start = 12, - stop = 72 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpCount string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 1), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_count(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_count(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_count(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 59 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpSubStr string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD") - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", ""), - CollationTestCase("ABCDE", ".C.", "UNICODE", "BCD"), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", "") - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_substr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_substr(collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_substr(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 60 + 2 * ct.collation.length - ) - ) - }) - } - - test("Support RegExpInStr string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))"), Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABCDE", ".c.", "UTF8_BINARY_LCASE", 0), - CollationTestCase("ABCDE", ".C.", "UNICODE", 2), - CollationTestCase("ABCDE", ".c.", "UNICODE_CI", 0) - ) - fails.foreach(ct => { - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT regexp_instr(collate('${ct.s1}', '${ct.collation}')" + - s",collate('${ct.s2}', '${ct.collation}'))") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"regexp_instr(collate(${ct.s1}), collate(${ct.s2}), 0)\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate(${ct.s1})\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"regexp_instr(collate('${ct.s1}', '${ct.collation}')," + - s"collate('${ct.s2}', '${ct.collation}'))", - start = 7, - stop = 59 + 2 * ct.collation.length - ) - ) - }) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 10127ef38c33..9a8ffb6efa6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import scala.collection.immutable.Seq +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession { @@ -434,5 +436,9 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession ) }) } +} +class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala deleted file mode 100644 index 73568cfcb862..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsANSISuite.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.immutable.Seq - -import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession - -class CollationStringExpressionsANSISuite extends QueryTest with SharedSparkSession { - - override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.ANSI_ENABLED, true) - case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) - - test("Support ConcatWs string expression with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL") - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT concat_ws(collate(' ', '${ct.collation}'), " + - s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))"), - Row(ct.expectedResult)) - }) - // Unsupported collations - val fails = Seq( - CollationTestCase("ABC", "%b%", "UTF8_BINARY_LCASE", false), - CollationTestCase("ABC", "%B%", "UNICODE", true), - CollationTestCase("ABC", "%b%", "UNICODE_CI", false) - ) - fails.foreach(ct => { - val expr = s"concat_ws(collate(' ', '${ct.collation}'), " + - s"collate('${ct.s1}', '${ct.collation}'), collate('${ct.s2}', '${ct.collation}'))" - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT $expr") - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"concat_ws(collate( ), collate(${ct.s1}), collate(${ct.s2}))\"", - "paramIndex" -> "first", - "inputSql" -> s"\"collate( )\"", - "inputType" -> s"\"STRING COLLATE ${ct.collation}\"", - "requiredType" -> "\"STRING\"" - ), - context = ExpectedContext( - fragment = s"$expr", - start = 7, - stop = 73 + 3 * ct.collation.length - ) - ) - }) - } - - // TODO: Add more tests for other string expressions - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index b20aceb11b90..9d4a11acff29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import scala.collection.immutable.Seq +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { @@ -67,7 +69,10 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession ) }) } - // TODO: Add more tests for other string expressions - } + +class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, true) +} \ No newline at end of file From 90519233f47a2fdc07f5ec1bab1dd6cfb4810b4c Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 20 Mar 2024 11:22:31 +0100 Subject: [PATCH 25/26] Fix newlines --- .../org/apache/spark/sql/CollationRegexpExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 9a8ffb6efa6b..6ab85d66a583 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -441,4 +441,4 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED, true) -} +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 9d4a11acff29..ddc11f21c5e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -69,7 +69,9 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession ) }) } + // TODO: Add more tests for other string expressions + } class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite { From 6986b8b144fc590361e38ed0a54d2b2537f402dd Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 20 Mar 2024 11:36:53 +0100 Subject: [PATCH 26/26] Add back newlines --- .../org/apache/spark/sql/CollationRegexpExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 6ab85d66a583..9a8ffb6efa6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -441,4 +441,4 @@ class CollationRegexpExpressionsSuite extends QueryTest with SharedSparkSession class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED, true) -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index ddc11f21c5e8..04f3781a92cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -77,4 +77,4 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite { override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED, true) -} \ No newline at end of file +}