Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1529,9 +1529,10 @@ public static UTF8String trimRight(
}

public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim,
final int limit, final int collationId) {
final int limit, final int collationId, boolean legacySplitTruncate) {
if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) {
return input.split(delim, limit);
return legacySplitTruncate ?
input.splitLegacyTruncate(delim, limit) : input.split(delim, limit);
} else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) {
return lowercaseSplitSQL(input, delim, limit);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,10 @@ public static int collationAwareRegexFlags(final int collationId) {
public static UTF8String lowercaseRegex(final UTF8String regex) {
return UTF8String.concat(lowercaseRegexPrefix, regex);
}
public static UTF8String collationAwareRegex(final UTF8String regex, final int collationId) {
return supportsLowercaseRegex(collationId) ? lowercaseRegex(regex) : regex;
public static UTF8String collationAwareRegex(
final UTF8String regex, final int collationId, boolean notIgnoreEmpty) {
return supportsLowercaseRegex(collationId) && (notIgnoreEmpty || regex.numBytes() != 0)
? lowercaseRegex(regex) : regex;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,25 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
}

public UTF8String[] split(UTF8String pattern, int limit) {
// For the empty `pattern` a `split` function ignores trailing empty strings unless original
// string is empty.
if (numBytes() != 0 && pattern.numBytes() == 0) {
int newLimit = limit > numChars() || limit <= 0 ? numChars() : limit;
byte[] input = getBytes();
int byteIndex = 0;
UTF8String[] result = new UTF8String[newLimit];
for (int charIndex = 0; charIndex < newLimit - 1; charIndex++) {
int currCharNumBytes = numBytesForFirstByte(input[byteIndex]);
result[charIndex] = UTF8String.fromBytes(input, byteIndex, currCharNumBytes);
byteIndex += currCharNumBytes;
}
result[newLimit - 1] = UTF8String.fromBytes(input, byteIndex, numBytes() - byteIndex);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is ArrayIndexOutOfBoundsException possible here?
what if newLimit=0 (i.e. numChars()=0, limit=-1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is ArrayIndexOutOfBoundsException possible here? what if newLimit=0 (i.e. numChars()=0, limit=-1)

no, this code block will only be entered when the following conditions are met.

if (numBytes() != 0 && pattern.numBytes() == 0)

return result;
}
return split(pattern.toString(), limit);
}

public UTF8String[] splitLegacyTruncate(UTF8String pattern, int limit) {
// For the empty `pattern` a `split` function ignores trailing empty strings unless original
// string is empty.
if (numBytes() != 0 && pattern.numBytes() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ public void split() {
new UTF8String[]{fromString("a"), fromString("b")},
fromString("ab").split(fromString(""), 100));
assertArrayEquals(
new UTF8String[]{fromString("a")},
new UTF8String[]{fromString("ab")},
fromString("ab").split(fromString(""), 1));
assertArrayEquals(
new UTF8String[]{fromString("")},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,18 +597,21 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E

private final lazy val collationId: Int = text.dataType.asInstanceOf[StringType].collationId

private lazy val legacySplitTruncate =
SQLConf.get.getConf(SQLConf.LEGACY_TRUNCATE_FOR_EMPTY_REGEX_SPLIT)

override def nullSafeEval(
inputString: Any,
stringDelimiter: Any,
keyValueDelimiter: Any): Any = {
val keyValues = CollationAwareUTF8String.splitSQL(inputString.asInstanceOf[UTF8String],
stringDelimiter.asInstanceOf[UTF8String], -1, collationId)
stringDelimiter.asInstanceOf[UTF8String], -1, collationId, legacySplitTruncate)
val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String]

var i = 0
while (i < keyValues.length) {
val keyValueArray = CollationAwareUTF8String.splitSQL(
keyValues(i), keyValueDelimiterUTF8String, 2, collationId)
keyValues(i), keyValueDelimiterUTF8String, 2, collationId, legacySplitTruncate)
val key = keyValueArray(0)
val value = if (keyValueArray.length < 2) null else keyValueArray(1)
mapBuilder.put(key, value)
Expand All @@ -623,9 +626,11 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E

nullSafeCodeGen(ctx, ev, (text, pd, kvd) =>
s"""
|UTF8String[] $keyValues = CollationAwareUTF8String.splitSQL($text, $pd, -1, $collationId);
|UTF8String[] $keyValues =
| CollationAwareUTF8String.splitSQL($text, $pd, -1, $collationId, $legacySplitTruncate);
|for(UTF8String kvEntry: $keyValues) {
| UTF8String[] kv = CollationAwareUTF8String.splitSQL(kvEntry, $kvd, 2, $collationId);
| UTF8String[] kv = CollationAwareUTF8String.splitSQL(
| kvEntry, $kvd, 2, $collationId, $legacySplitTruncate);
| $builderTerm.put(kv[0], kv.length == 2 ? kv[1] : null);
|}
|${ev.value} = $builderTerm.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{
StringTypeBinaryLcase, StringTypeWithCollation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{StringTypeBinaryLcase, StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -580,20 +580,33 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)

final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId

private lazy val legacySplitTruncate =
SQLConf.get.getConf(SQLConf.LEGACY_TRUNCATE_FOR_EMPTY_REGEX_SPLIT)

def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))

override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val pattern = CollationSupport.collationAwareRegex(regex.asInstanceOf[UTF8String], collationId)
val strings = string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int])
val pattern = CollationSupport.collationAwareRegex(
regex.asInstanceOf[UTF8String], collationId, legacySplitTruncate)
val strings = if (legacySplitTruncate) {
string.asInstanceOf[UTF8String].splitLegacyTruncate(pattern, limit.asInstanceOf[Int])
} else {
string.asInstanceOf[UTF8String].split(pattern, limit.asInstanceOf[Int])
}
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
val pattern = ctx.freshName("pattern")
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split(
|CollationSupport.collationAwareRegex($regex, $collationId),$limit));""".stripMargin
s"""
|UTF8String $pattern =
| CollationSupport.collationAwareRegex($regex, $collationId, $legacySplitTruncate);
|${ev.value} = new $arrayClass($legacySplitTruncate ?
| $str.splitLegacyTruncate($pattern, $limit) : $str.split($pattern, $limit));
|""".stripMargin
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6158,6 +6158,21 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val LEGACY_TRUNCATE_FOR_EMPTY_REGEX_SPLIT =
buildConf("spark.sql.legacy.truncateForEmptyRegexSplit")
.internal()
.doc("When set to true, splitting a string of length n using an empty regex with a " +
"positive limit discards the last n - limit characters." +
"For example: SELECT split('abcd', '', 2) returns ['a', 'b']." +
"When set to false, the last element of the resulting array contains all input beyond " +
"the last matched regex." +
"For example: SELECT split('abcd', '', 2) returns ['a', 'bcd']." +
"According to the description of the split function, this should be set to false by " +
"default. See SPARK-49968 for details.")
.version("4.1.0")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,37 @@ class CollationRegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalH
}

test("StringSplit expression with collated strings") {
case class StringSplitTestCase[R](s: String, r: String, collation: String, expected: R)
case class StringSplitTestCase[R](s: String, r: String, collation: String,
expected: R, limit: Int)
val testCases = Seq(
StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_BINARY", Seq("1", "2", "3", "")),
StringSplitTestCase("1A2B3C", "[abc]", "UTF8_BINARY", Seq("1A2B3C")),
StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_LCASE", Seq("1", "2", "3", "")),
StringSplitTestCase("1A2B3C", "[abc]", "UTF8_LCASE", Seq("1", "2", "3", "")),
StringSplitTestCase("1A2B3C", "[1-9]+", "UTF8_BINARY", Seq("", "A", "B", "C")),
StringSplitTestCase("", "", "UTF8_BINARY", Seq("")),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2", "B", "3", "C")),
StringSplitTestCase("", "[1-9]+", "UTF8_BINARY", Seq("")),
StringSplitTestCase(null, "[1-9]+", "UTF8_BINARY", null),
StringSplitTestCase("1A2B3C", null, "UTF8_BINARY", null),
StringSplitTestCase(null, null, "UTF8_BINARY", null)
StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_BINARY", Seq("1", "2", "3", ""), -1),
StringSplitTestCase("1A2B3C", "[abc]", "UTF8_BINARY", Seq("1A2B3C"), -1),
StringSplitTestCase("1A2B3C", "[ABC]", "UTF8_LCASE", Seq("1", "2", "3", ""), -1),
StringSplitTestCase("1A2B3C", "[abc]", "UTF8_LCASE", Seq("1", "2", "3", ""), -1),
StringSplitTestCase("1A2B3C", "[1-9]+", "UTF8_BINARY", Seq("", "A", "B", "C"), -1),
StringSplitTestCase("", "", "UTF8_BINARY", Seq(""), -1),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2", "B", "3", "C"), -1),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1", "A", "2", "B", "3", "C"), -1),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2", "B", "3", "C"), 0),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1", "A", "2", "B", "3", "C"), 0),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1A2B3C"), 1),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1A2B3C"), 1),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2B3C"), 3),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1", "A", "2B3C"), 3),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2", "B", "3", "C"), 6),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1", "A", "2", "B", "3", "C"), 6),
StringSplitTestCase("1A2B3C", "", "UTF8_BINARY", Seq("1", "A", "2", "B", "3", "C"), 100),
StringSplitTestCase("1A2B3C", "", "UTF8_LCASE", Seq("1", "A", "2", "B", "3", "C"), 100),
StringSplitTestCase("", "[1-9]+", "UTF8_BINARY", Seq(""), -1),
StringSplitTestCase(null, "[1-9]+", "UTF8_BINARY", null, -1),
StringSplitTestCase("1A2B3C", null, "UTF8_BINARY", null, -1),
StringSplitTestCase(null, null, "UTF8_BINARY", null, -1)
)
testCases.foreach(t => {
// StringSplit
checkEvaluation(StringSplit(
Literal.create(t.s, StringType(CollationFactory.collationNameToId(t.collation))),
Literal.create(t.r, StringType), -1), t.expected)
Literal.create(t.r, StringType), t.limit), t.expected)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StringSplit(Literal("hello"), Literal(""), 5), Seq("h", "e", "l", "l", "o"), row1)
checkEvaluation(
StringSplit(Literal("hello"), Literal(""), 3), Seq("h", "e", "l"), row1)
StringSplit(Literal("hello"), Literal(""), 3), Seq("h", "e", "llo"), row1)
checkEvaluation(
StringSplit(Literal("hello"), Literal(""), 100), Seq("h", "e", "l", "l", "o"), row1)
checkEvaluation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,69 @@ Project [split(hello, , -1) AS split(hello, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 0)
-- !query analysis
Project [split(hello, , 0) AS split(hello, , 0)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 1)
-- !query analysis
Project [split(hello, , 1) AS split(hello, , 1)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 3)
-- !query analysis
Project [split(hello, , 3) AS split(hello, , 3)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 5)
-- !query analysis
Project [split(hello, , 5) AS split(hello, , 5)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 100)
-- !query analysis
Project [split(hello, , 100) AS split(hello, , 100)#x]
+- OneRowRelation


-- !query
SELECT split('', '')
-- !query analysis
Project [split(, , -1) AS split(, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('', '', -1)
-- !query analysis
Project [split(, , -1) AS split(, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('', '', 0)
-- !query analysis
Project [split(, , 0) AS split(, , 0)#x]
+- OneRowRelation


-- !query
SELECT split('', '', 1)
-- !query analysis
Project [split(, , 1) AS split(, , 1)#x]
+- OneRowRelation


-- !query
SELECT split('abc', null)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,69 @@ Project [split(hello, , -1) AS split(hello, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 0)
-- !query analysis
Project [split(hello, , 0) AS split(hello, , 0)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 1)
-- !query analysis
Project [split(hello, , 1) AS split(hello, , 1)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 3)
-- !query analysis
Project [split(hello, , 3) AS split(hello, , 3)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 5)
-- !query analysis
Project [split(hello, , 5) AS split(hello, , 5)#x]
+- OneRowRelation


-- !query
SELECT split('hello', '', 100)
-- !query analysis
Project [split(hello, , 100) AS split(hello, , 100)#x]
+- OneRowRelation


-- !query
SELECT split('', '')
-- !query analysis
Project [split(, , -1) AS split(, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('', '', -1)
-- !query analysis
Project [split(, , -1) AS split(, , -1)#x]
+- OneRowRelation


-- !query
SELECT split('', '', 0)
-- !query analysis
Project [split(, , 0) AS split(, , 0)#x]
+- OneRowRelation


-- !query
SELECT split('', '', 1)
-- !query analysis
Project [split(, , 1) AS split(, , 1)#x]
+- OneRowRelation


-- !query
SELECT split('abc', null)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a');
SELECT split('aa1cc2ee3', '[1-9]+');
SELECT split('aa1cc2ee3', '[1-9]+', 2);
SELECT split('hello', '');
SELECT split('hello', '', 0);
SELECT split('hello', '', 1);
SELECT split('hello', '', 3);
SELECT split('hello', '', 5);
SELECT split('hello', '', 100);
SELECT split('', '');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also prefer to see:

SELECT split('', '', -1);
SELECT split('', '', 0);
SELECT split('', '', 1);

here, for more complete testing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, already added.

SELECT split('', '', -1);
SELECT split('', '', 0);
SELECT split('', '', 1);
SELECT split('abc', null);
SELECT split(null, 'b');

Expand Down
Loading