diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 5eebec7f1301..a5bb1fe715bb 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -206,21 +206,22 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class Upper { - public static UTF8String exec(final UTF8String v, final int collationId) { + public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(v); + return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.supportsLowercaseEquality) { return execLowercase(v); } else { return execICU(v, collationId); } } - public static String genCode(final String v, final int collationId) { + public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s)", v); + String funcName = useICU ? "BinaryICU" : "Binary"; + return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); } else { @@ -230,6 +231,9 @@ public static String genCode(final String v, final int collationId) { public static UTF8String execBinary(final UTF8String v) { return v.toUpperCase(); } + public static UTF8String execBinaryICU(final UTF8String v) { + return CollationAwareUTF8String.toUpperCase(v); + } public static UTF8String execLowercase(final UTF8String v) { return CollationAwareUTF8String.toUpperCase(v); } @@ -239,21 +243,22 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { } public static class Lower { - public static UTF8String exec(final UTF8String v, final int collationId) { + public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(v); + return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.supportsLowercaseEquality) { return execLowercase(v); } else { return execICU(v, collationId); } } - public static String genCode(final String v, final int collationId) { + public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - String expr = "CollationSupport.Lower.exec"; + String expr = "CollationSupport.Lower.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s)", v); + String funcName = useICU ? "BinaryICU" : "Binary"; + return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); } else { @@ -263,6 +268,9 @@ public static String genCode(final String v, final int collationId) { public static UTF8String execBinary(final UTF8String v) { return v.toLowerCase(); } + public static UTF8String execBinaryICU(final UTF8String v) { + return CollationAwareUTF8String.toLowerCase(v); + } public static UTF8String execLowercase(final UTF8String v) { return CollationAwareUTF8String.toLowerCase(v); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 58826005fc46..99f35ef81dc6 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -606,7 +606,11 @@ private void assertUpper(String target, String collationName, String expected) UTF8String target_utf8 = UTF8String.fromString(target); UTF8String expected_utf8 = UTF8String.fromString(expected); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, collationId)); + // Testing the new ICU-based implementation of the Upper function. + assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, collationId, true)); + // Testing the old JVM-based implementation of the Upper function. + assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, collationId, false)); + // Note: results should be the same in these tests for both ICU and JVM-based implementations. } @Test @@ -660,7 +664,11 @@ private void assertLower(String target, String collationName, String expected) UTF8String target_utf8 = UTF8String.fromString(target); UTF8String expected_utf8 = UTF8String.fromString(expected); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, collationId)); + // Testing the new ICU-based implementation of the Lower function. + assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, collationId, true)); + // Testing the old JVM-based implementation of the Lower function. + assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, collationId, false)); + // Note: results should be the same in these tests for both ICU and JVM-based implementations. } @Test diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ac23962f41ed..e0a9d6f77edd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -455,12 +455,16 @@ case class Upper(child: Expression) final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId - override def convert(v: UTF8String): UTF8String = CollationSupport.Upper.exec(v, collationId) + // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. + private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) + + override def convert(v: UTF8String): UTF8String = + CollationSupport.Upper.exec(v, collationId, useICU) final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId)) + defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId, useICU)) } override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild) @@ -483,12 +487,16 @@ case class Lower(child: Expression) final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId - override def convert(v: UTF8String): UTF8String = CollationSupport.Lower.exec(v, collationId) + // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. + private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) + + override def convert(v: UTF8String): UTF8String = + CollationSupport.Lower.exec(v, collationId, useICU) final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId)) + defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId, useICU)) } override def prettyName: String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fd804bc0e986..799e54aaecea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -785,6 +785,14 @@ object SQLConf { _ => Map()) .createWithDefault("UTF8_BINARY") + val ICU_CASE_MAPPINGS_ENABLED = + buildConf("spark.sql.icu.caseMappings.enabled") + .doc("When enabled we use the ICU library (instead of the JVM) to implement case mappings" + + " for strings under UTF8_BINARY collation.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val FETCH_SHUFFLE_BLOCKS_IN_BATCH = buildConf("spark.sql.adaptive.fetchShuffleBlocksInBatch") .internal()