diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 934572cd0d67..7c0ffc73aa9d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -228,7 +228,7 @@ private static int compareLowerCaseAscii(final UTF8String left, final UTF8String * @return An integer representing the comparison result. */ private static int compareLowerCaseSlow(final UTF8String left, final UTF8String right) { - return lowerCaseCodePoints(left.toString()).compareTo(lowerCaseCodePoints(right.toString())); + return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right)); } public static UTF8String replace(final UTF8String src, final UTF8String search, @@ -339,11 +339,15 @@ public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String * @return the uppercase string */ public static UTF8String toUpperCase(final UTF8String target) { - return UTF8String.fromString(toUpperCase(target.toString())); + if (target.isFullAscii()) return target.toUpperCaseAscii(); + return toUpperCaseSlow(target); } - public static String toUpperCase(final String target) { - return UCharacter.toUpperCase(target); + private static UTF8String toUpperCaseSlow(final UTF8String target) { + // Note: In order to achieve the desired behaviour, we use the ICU UCharacter class to + // convert the string to uppercase, which only accepts Java strings as input. 
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toUpperCase(target.toString())); } /** @@ -353,13 +357,17 @@ public static String toUpperCase(final String target) { * @return the uppercase string */ public static UTF8String toUpperCase(final UTF8String target, final int collationId) { - return UTF8String.fromString(toUpperCase(target.toString(), collationId)); + if (target.isFullAscii()) return target.toUpperCaseAscii(); + return toUpperCaseSlow(target, collationId); } - public static String toUpperCase(final String target, final int collationId) { + private static UTF8String toUpperCaseSlow(final UTF8String target, final int collationId) { + // Note: In order to achieve the desired behaviour, we use the ICU UCharacter class to + // convert the string to uppercase, which only accepts Java strings as input. ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toUpperCase(locale, target); + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toUpperCase(locale, target.toString())); } /** @@ -369,10 +377,15 @@ public static String toUpperCase(final String target, final int collationId) { * @return the lowercase string */ public static UTF8String toLowerCase(final UTF8String target) { - return UTF8String.fromString(toLowerCase(target.toString())); + if (target.isFullAscii()) return target.toLowerCaseAscii(); + return toLowerCaseSlow(target); } - public static String toLowerCase(final String target) { - return UCharacter.toLowerCase(target); + + private static UTF8String toLowerCaseSlow(final UTF8String target) { + // Note: In order to achieve the desired behaviour, we use the ICU UCharacter class to + // convert the string to lowercase, which only accepts Java strings as input. 
+ // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toLowerCase(target.toString())); } /** @@ -382,12 +395,17 @@ public static String toLowerCase(final String target) { * @return the lowercase string */ public static UTF8String toLowerCase(final UTF8String target, final int collationId) { - return UTF8String.fromString(toLowerCase(target.toString(), collationId)); + if (target.isFullAscii()) return target.toLowerCaseAscii(); + return toLowerCaseSlow(target, collationId); } - public static String toLowerCase(final String target, final int collationId) { + + private static UTF8String toLowerCaseSlow(final UTF8String target, final int collationId) { + // Note: In order to achieve the desired behaviour, we use the ICU UCharacter class to + // convert the string to lowercase, which only accepts Java strings as input. ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toLowerCase(locale, target); + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toLowerCase(locale, target.toString())); } /** @@ -424,36 +442,41 @@ else if (codePoint == 0x03C2) { * @param target The target string to convert to lowercase. * @return The string converted to lowercase in a context-unaware manner. 
*/ - public static String lowerCaseCodePoints(final String target) { + public static UTF8String lowerCaseCodePoints(final UTF8String target) { + if (target.isFullAscii()) return target.toLowerCaseAscii(); + return lowerCaseCodePointsSlow(target); + } + + private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + String targetString = target.toString(); StringBuilder sb = new StringBuilder(); - for (int i = 0; i < target.length(); ++i) { - lowercaseCodePoint(target.codePointAt(i), sb); + for (int i = 0; i < targetString.length(); i = targetString.offsetByCodePoints(i, 1)) { + lowercaseCodePoint(targetString.codePointAt(i), sb); } - return sb.toString(); + return UTF8String.fromString(sb.toString()); } /** * Convert the input string to titlecase using the ICU root locale rules. */ public static UTF8String toTitleCase(final UTF8String target) { - return UTF8String.fromString(toTitleCase(target.toString())); - } - - public static String toTitleCase(final String target) { - return UCharacter.toTitleCase(target, BreakIterator.getWordInstance()); + // Note: In order to achieve the desired behaviour, we use the ICU UCharacter class to + // convert the string to titlecase, which only accepts Java strings as input. + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toTitleCase(target.toString(), + BreakIterator.getWordInstance())); } /** * Convert the input string to titlecase using the specified ICU collation rules. 
*/ public static UTF8String toTitleCase(final UTF8String target, final int collationId) { - return UTF8String.fromString(toTitleCase(target.toString(), collationId)); - } - - public static String toTitleCase(final String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` + return UTF8String.fromString(UCharacter.toTitleCase(locale, target.toString(), + BreakIterator.getWordInstance(locale))); } public static int findInSet(final UTF8String match, final UTF8String set, int collationId) { @@ -461,6 +484,7 @@ public static int findInSet(final UTF8String match, final UTF8String set, int co return 0; } + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` String setString = set.toString(); StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), collationId); @@ -623,6 +647,7 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, public static Map getCollationAwareDict(UTF8String string, Map dict, int collationId) { + // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` String srcStr = string.toString(); Map collationAwareDict = new HashMap<>(); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 61ec6f7da215..b0f6c5c22991 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -299,7 +299,7 @@ private static Collation fetchCollation(int collationId) { == DefinitionOrigin.PREDEFINED); if (collationId == UTF8_BINARY_COLLATION_ID) { // 
Skip cache. - return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION; + return CollationSpecUTF8.UTF8_BINARY_COLLATION; } else if (collationMap.containsKey(collationId)) { // Already in cache. return collationMap.get(collationId); @@ -308,7 +308,7 @@ private static Collation fetchCollation(int collationId) { CollationSpec spec; ImplementationProvider implementationProvider = getImplementationProvider(collationId); if (implementationProvider == ImplementationProvider.UTF8_BINARY) { - spec = CollationSpecUTF8Binary.fromCollationId(collationId); + spec = CollationSpecUTF8.fromCollationId(collationId); } else { spec = CollationSpecICU.fromCollationId(collationId); } @@ -327,7 +327,7 @@ private static int collationNameToId(String collationName) throws SparkException // Collation names provided by user are treated as case-insensitive. String collationNameUpper = collationName.toUpperCase(); if (collationNameUpper.startsWith("UTF8_")) { - return CollationSpecUTF8Binary.collationNameToId(collationName, collationNameUpper); + return CollationSpecUTF8.collationNameToId(collationName, collationNameUpper); } else { return CollationSpecICU.collationNameToId(collationName, collationNameUpper); } @@ -336,7 +336,7 @@ private static int collationNameToId(String collationName) throws SparkException protected abstract Collation buildCollation(); } - private static class CollationSpecUTF8Binary extends CollationSpec { + private static class CollationSpecUTF8 extends CollationSpec { /** * Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for UTF8_LCASE @@ -357,17 +357,17 @@ private enum CaseSensitivity { private static final int CASE_SENSITIVITY_MASK = 0b1; private static final int UTF8_BINARY_COLLATION_ID = - new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId; + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; private static final int UTF8_LCASE_COLLATION_ID = - new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId; + new 
CollationSpecUTF8(CaseSensitivity.LCASE).collationId; protected static Collation UTF8_BINARY_COLLATION = - new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation(); protected static Collation UTF8_LCASE_COLLATION = - new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation(); private final int collationId; - private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) { + private CollationSpecUTF8(CaseSensitivity caseSensitivity) { this.collationId = SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); } @@ -384,14 +384,14 @@ private static int collationNameToId(String originalName, String collationName) } } - private static CollationSpecUTF8Binary fromCollationId(int collationId) { + private static CollationSpecUTF8 fromCollationId(int collationId) { // Extract case sensitivity from collation ID. int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. 
assert (SpecifierUtils.removeSpec(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); - return new CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]); + return new CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]); } @Override @@ -414,7 +414,7 @@ protected Collation buildCollation() { null, CollationAwareUTF8String::compareLowerCase, "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(), + s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, /* supportsLowercaseEquality = */ true); @@ -727,9 +727,9 @@ public CollationIdentifier identifier() { public static final List SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); public static final int UTF8_BINARY_COLLATION_ID = - Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID; + Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID; public static final int UTF8_LCASE_COLLATION_ID = - Collation.CollationSpecUTF8Binary.UTF8_LCASE_COLLATION_ID; + Collation.CollationSpecUTF8.UTF8_LCASE_COLLATION_ID; public static final int UNICODE_COLLATION_ID = Collation.CollationSpecICU.UNICODE_COLLATION_ID; public static final int UNICODE_CI_COLLATION_ID = diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e7f16988c537..12a7b06232ee 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.function.Function; import java.util.Map; import java.util.regex.Pattern; @@ -495,6 +496,18 @@ public boolean endsWith(final UTF8String suffix) { return matchAt(suffix, numBytes - suffix.numBytes); } + 
/** + * Applies the given char converter to each byte of this string; meant for ASCII-only input. + */ + + private UTF8String convertAscii(Function<Character, Character> charConverter) { + byte[] bytes = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) { + bytes[i] = (byte) charConverter.apply((char) getByte(i)).charValue(); + } + return fromBytes(bytes); + } + /** * Returns the upper case of this string */ @@ -502,18 +515,12 @@ public UTF8String toUpperCase() { if (numBytes == 0) { return EMPTY_UTF8; } - // Optimization - do char level uppercase conversion in case of chars in ASCII range - for (int i = 0; i < numBytes; i++) { - if (getByte(i) < 0) { - // non-ASCII - return toUpperCaseSlow(); - } - } - byte[] bytes = new byte[numBytes]; - for (int i = 0; i < numBytes; i++) { - bytes[i] = (byte) Character.toUpperCase(getByte(i)); - } - return fromBytes(bytes); + + return isFullAscii() ? toUpperCaseAscii() : toUpperCaseSlow(); + } + + public UTF8String toUpperCaseAscii() { + return convertAscii(Character::toUpperCase); } private UTF8String toUpperCaseSlow() { @@ -544,12 +551,8 @@ private UTF8String toLowerCaseSlow() { return fromString(toString().toLowerCase()); } - private UTF8String toLowerCaseAscii() { - final var bytes = new byte[numBytes]; - for (var i = 0; i < numBytes; i++) { - bytes[i] = (byte) Character.toLowerCase(getByte(i)); - } - return fromBytes(bytes); + public UTF8String toLowerCaseAscii() { + return convertAscii(Character::toLowerCase); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 436dff1db0e0..9602c83c6c80 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -41,7 +41,7 @@ public class CollationSupportSuite { */ private void assertStringCompare(String s1, String s2, String collationName, 
int expected) - throws SparkException { + throws SparkException { UTF8String l = UTF8String.fromString(s1); UTF8String r = UTF8String.fromString(s2); int compare = CollationFactory.fetchCollation(collationName).comparator.compare(l, r); @@ -129,13 +129,26 @@ public void testCompare() throws SparkException { assertStringCompare("ς", "σ", "UNICODE_CI", 0); assertStringCompare("ς", "Σ", "UNICODE_CI", 0); assertStringCompare("σ", "Σ", "UNICODE_CI", 0); + // Maximum code point. + int maxCodePoint = Character.MAX_CODE_POINT; + String maxCodePointStr = new String(Character.toChars(maxCodePoint)); + for (int i = 0; i < maxCodePoint && Character.isValidCodePoint(i); ++i) { + assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_BINARY", -1); + assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_LCASE", -1); + } + // Minimum code point. + int minCodePoint = Character.MIN_CODE_POINT; + String minCodePointStr = new String(Character.toChars(minCodePoint)); + for (int i = minCodePoint + 1; i <= maxCodePoint && Character.isValidCodePoint(i); ++i) { + assertStringCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_BINARY", 1); + assertStringCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_LCASE", 1); + } } private void assertLowerCaseCodePoints(UTF8String target, UTF8String expected, Boolean useCodePoints) { if (useCodePoints) { - assertEquals(expected.toString(), - CollationAwareUTF8String.lowerCaseCodePoints(target.toString())); + assertEquals(expected, CollationAwareUTF8String.lowerCaseCodePoints(target)); } else { assertEquals(expected, target.toLowerCase()); }