Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,19 @@ private static int compareLowerCaseSlow(final UTF8String left, final UTF8String
return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right));
}

/*
* Performs string replacement for ICU collations by searching for instances of the search
* string in the src string, with respect to the specified collation, and then replacing
* them with the replace string. The method returns a new UTF8String with all instances of the
* search string replaced using the replace string. Similar to UTF8String.replace behaviour
* used for UTF8_BINARY collation, the method returns src string if the search string is empty.
*
* @param src the string to be searched in
* @param search the string to be searched for
* @param replace the string to be used as replacement
* @param collationId the collation ID to use for string search
* @return the string obtained by replacing every occurrence of `search` in `src` with `replace`
*/
public static UTF8String replace(final UTF8String src, final UTF8String search,
final UTF8String replace, final int collationId) {
// This collation aware implementation is based on existing implementation on UTF8String
Expand Down Expand Up @@ -286,49 +299,47 @@ public static UTF8String replace(final UTF8String src, final UTF8String search,
return buf.build();
}

/**
 * Performs string replacement for UTF8_LCASE collation by searching for instances of the
 * lowercased search string in the src string and replacing each of them with the replace
 * string. Returns src itself when either src or search is empty, or when no match is found;
 * otherwise returns a newly built string.
 *
 * Matching is delegated to `lowercaseFind` / `lowercaseMatchLengthFrom` (defined elsewhere in
 * this class); `start` and `end` are the positions those helpers and `src.substring` work with.
 *
 * @param src the string to be searched in
 * @param search the string to be searched for
 * @param replace the string to be used as replacement
 * @return `src` with every match of `search` replaced by `replace`
 */
public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search,
    final UTF8String replace) {
  if (src.numBytes() == 0 || search.numBytes() == 0) {
    return src;
  }

  // TODO(SPARK-48725): Use lowerCaseCodePoints instead of UTF8String.toLowerCase.
  UTF8String lowercaseSearch = search.toLowerCase();

  int start = 0;
  int end = lowercaseFind(src, lowercaseSearch, start);
  if (end == -1) {
    // Search string was not found, so string is unchanged.
    return src;
  }

  // At least one match was found. Estimate space needed for result.
  // The 16x multiplier here is chosen to match commons-lang3's implementation.
  int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
  final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
  while (end != -1) {
    // Copy the unmatched stretch before this match, then the replacement.
    buf.append(src.substring(start, end));
    buf.append(replace);
    // Update character positions. The match length is measured in `src` rather than taken
    // from the search string, since the matched text in `src` may differ in length from it.
    start = end + lowercaseMatchLengthFrom(src, lowercaseSearch, end);
    end = lowercaseFind(src, lowercaseSearch, start);
  }
  // Copy whatever follows the last match.
  buf.append(src.substring(start, src.numChars()));
  return buf.build();
}

Expand Down Expand Up @@ -479,34 +490,40 @@ public static UTF8String toTitleCase(final UTF8String target, final int collatio
BreakIterator.getWordInstance(locale)));
}

/*
* Returns the position of the first occurrence of the match string in the set string,
* counting ASCII commas as delimiters. The match string is compared in a collation-aware manner,
* with respect to the specified collation ID. Similar to UTF8String.findInSet behaviour used
* for UTF8_BINARY collation, the method returns 0 if the match string contains no commas.
*
* @param match the string to be searched for
* @param set the string to be searched in
* @param collationId the collation ID to use for string comparison
* @return the position of the first occurrence of `match` in `set`
*/
public static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
// If the "word" string contains a comma, FindInSet should return 0.
if (match.contains(UTF8String.fromString(","))) {
return 0;
}

// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
collationId);

int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
// Otherwise, search for commas in "set" and compare each substring with "word".
int byteIndex = 0, charIndex = 0, wordCount = 1, lastComma = -1;
while (byteIndex < set.numBytes()) {
byte nextByte = set.getByte(byteIndex);
if (nextByte == (byte) ',') {
if (set.substring(lastComma + 1, charIndex).semanticEquals(match, collationId)) {
return wordCount;
}

return pos + 1;
lastComma = charIndex;
++wordCount;
}
byteIndex += UTF8String.numBytesForFirstByte(nextByte);
++charIndex;
}

if (set.substring(lastComma + 1, set.numBytes()).semanticEquals(match, collationId)) {
return wordCount;
}
// If no match is found, return 0.
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,31 +318,24 @@ public static int exec(final UTF8String word, final UTF8String set, final int co
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(word, set);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(word, set);
} else {
return execICU(word, set, collationId);
return execCollationAware(word, set, collationId);
}
}
/**
 * Builds the source text of a call to the FindInSet implementation that matches the given
 * collation: `execBinary` when the collation supports binary equality, otherwise
 * `execCollationAware` (which also receives the collation ID).
 *
 * @param word source text of the expression producing the word to look for
 * @param set source text of the expression producing the comma-delimited set
 * @param collationId the collation ID used to pick the implementation
 * @return source text of the method call
 */
public static String genCode(final String word, final String set, final int collationId) {
  CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
  String expr = "CollationSupport.FindInSet.exec";
  if (collation.supportsBinaryEquality) {
    return String.format(expr + "Binary(%s, %s)", word, set);
  } else {
    // `expr` already ends in "exec"; append only the suffix so the generated call names
    // execCollationAware rather than the nonexistent "execexecCollationAware".
    return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId);
  }
}
/**
 * FindInSet for collations that support binary equality: delegates directly to
 * `UTF8String.findInSet` on the unmodified inputs.
 *
 * @param word the string to look for
 * @param set the comma-delimited string to search in
 * @return whatever `set.findInSet(word)` returns
 */
public static int execBinary(final UTF8String word, final UTF8String set) {
return set.findInSet(word);
}
/**
 * FindInSet on lowercased inputs: lowercases both `set` and `word` with
 * `UTF8String.toLowerCase` and then delegates to `UTF8String.findInSet`.
 *
 * NOTE(review): the dispatch code visible nearby also routes non-binary collations to
 * execCollationAware — confirm whether this method is still needed.
 *
 * @param word the string to look for
 * @param set the comma-delimited string to search in
 * @return whatever `findInSet` returns for the lowercased inputs
 */
public static int execLowercase(final UTF8String word, final UTF8String set) {
return set.toLowerCase().findInSet(word.toLowerCase());
}
public static int execICU(final UTF8String word, final UTF8String set,
final int collationId) {
public static int execCollationAware(final UTF8String word, final UTF8String set,
final int collationId) {
return CollationAwareUTF8String.findInSet(word, set, collationId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,47 +875,105 @@ public void testStringInstr() throws SparkException {
assertStringInstr("aİoi̇oxx", "XX", "UTF8_LCASE", 7);
}

private void assertFindInSet(String word, String set, String collationName,
Integer expected) throws SparkException {
private void assertFindInSet(String word, UTF8String set, String collationName,
Integer expected) throws SparkException {
UTF8String w = UTF8String.fromString(word);
UTF8String s = UTF8String.fromString(set);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId));
assertEquals(expected, CollationSupport.FindInSet.exec(w, set, collationId));
}

@Test
public void testFindInSet() throws SparkException {
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("a", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("c", "abc,b,ab,c,def", "UTF8_LCASE", 4);
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_LCASE", 3);
assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_LCASE", 1);
assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("XX", "xx", "UTF8_LCASE", 1);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_LCASE", 4);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("xx", "xx", "UNICODE", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("XX", "xx", "UNICODE_CI", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2);
assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 1);
assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 5);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_BINARY", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_BINARY", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UTF8_BINARY", 0);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("c", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 4);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 3);
assertFindInSet("AbC", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 1);
assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UTF8_LCASE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_LCASE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_LCASE", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UTF8_LCASE", 0);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_LCASE", 4);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 3);
assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UNICODE", 0);
assertFindInSet("xx", UTF8String.fromString("xx"), "UNICODE", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 0);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 5);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 4);
assertFindInSet("DeF", UTF8String.fromString("abc,b,ab,c,dEf"), "UNICODE_CI", 5);
assertFindInSet("DEFG", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE_CI", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE_CI", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UNICODE_CI", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UNICODE_CI", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 4);
assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UNICODE_CI", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UNICODE_CI", 2);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UTF8_LCASE", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UTF8_LCASE", 2);
// Invalid UTF8 strings
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_BINARY", 3);
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_LCASE", 2);
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE", 2);
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE_CI", 2);
}

private void assertReplace(String source, String search, String replace, String collationName,
Expand Down Expand Up @@ -952,8 +1010,23 @@ public void testReplace() throws SparkException {
assertReplace("replace", "", "123", "UNICODE_CI", "replace");
assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c");
assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad");
assertReplace("abi̇12", "i", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "\u0307", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "İ", "X", "UNICODE_CI", "abX12");
assertReplace("abİ12", "i", "X", "UNICODE_CI", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UNICODE_CI", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UNICODE_CI", "İi̇İi̇İi̇");
assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
assertReplace("abi̇12", "i", "X", "UTF8_LCASE", "abX\u030712"); // != UNICODE_CI
assertReplace("abi̇12", "\u0307", "X", "UTF8_LCASE", "abiX12"); // != UNICODE_CI
assertReplace("abi̇12", "İ", "X", "UTF8_LCASE", "abX12");
assertReplace("abİ12", "i", "X", "UTF8_LCASE", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UTF8_LCASE", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UTF8_LCASE",
"İx\u0307İx\u0307İx\u0307"); // != UNICODE_CI
assertReplace("abİo12i̇o", "i̇o", "xx", "UTF8_LCASE", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UTF8_LCASE", "abyy12yy");
}

private void assertLocate(String substring, String string, Integer start, String collationName,
Expand Down